Manual scheduling test for outer TMA reduction #5926

Open
tbqh wants to merge 1 commit into main from tbqh/outer_reduction_tma_manual

Conversation

@tbqh (Collaborator) commented Feb 6, 2026

Add manually scheduled test for outer-reduction with 2D TMA.

Dumps for outer_16384_iter_16384:

TMA cuda_kernel
__global__ void nvfuser_none_f0_c0_r0_g0(Tensor<float, 2, 2> T0, const __grid_constant__ TensorMap var0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T6, Tensor<int64_t, 1, 1> T7) {
  alignas(128) extern __shared__ char array[];
  void* shared_mem = array;
  const unsigned smem_offset = alignBufferSize(32 * 16 * 1 * sizeof(float) * 4, 128);
  nvfuser_index_t i1;
  i1 = ceilDiv((ceilDiv(T0.logical_size[0LL], 128)), 8);
  const TensorMap* ptr2;
  ptr2 = &var0;
  nvfuser_index_t i3;
  i3 = 128 * ((nvfuser_index_t)blockIdx.x);
  int i4;
  i4 = __to_int32(i3);
  nvfuser_index_t i5;
  i5 = (128 * i1) * ((nvfuser_index_t)blockIdx.y);
  float* T2 = reinterpret_cast<float*>(array + smem_offset + 0);
  uint32_t i6;
  i6 = toSmem(T2);
  nvfuser_index_t i7;
  i7 = ((nvfuser_index_t)threadIdx.x) + (128 * ((nvfuser_index_t)threadIdx.y));
  nvfuser_index_t i8;
  i8 = ((nvfuser_index_t)threadIdx.x) + i3;
  nvfuser_index_t i9;
  i9 = ((-T0.logical_size[0LL]) + ((nvfuser_index_t)threadIdx.y)) + i5;
  nvfuser_index_t i10;
  i10 = ((-T0.logical_size[1LL]) + ((nvfuser_index_t)threadIdx.x)) + i3;
  bool b11;
  b11 = (((nvfuser_index_t)blockIdx.y) == (((nvfuser_index_t)gridDim.y) + -1)) && (((nvfuser_index_t)threadIdx.y) == 0);
  Array<float, 4, 1> T3;
  #pragma unroll
  for (int i = 0; i < 4; ++i) {
    T3[(0) + i] = 0.000000000e+00f;
  }
  // Allocate global tensor T6
  // Allocate global tensor T7
  Array<float, 4, 4> T4;
  #pragma unroll
  for(nvfuser_index_t i12 = 0; i12 < 4; ++i12) {
    T4[i12] = 0.000000000e+00f;
  }
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < i1; ++i13) {
    nvfuser_index_t i14;
    i14 = 128 * i13;
    Array<int, 2, 1> a15;
    a15 = Array<int, 2, 1>{i4, __to_int32((i5 + i14))};
    nvfuser_index_t i16;
    i16 = i9 + i14;
    uint64_t* T5 = reinterpret_cast<uint64_t*>(array + smem_offset + 65536);
    mbarrier::init(toSmem(T5), 1U);
    __syncthreads();
    if ((Hopper::electSync(4294967295U) && (((nvfuser_index_t)threadIdx.y) == 0ULL))) {
      uint64_t i17;
      i17 = mbarrier::arriveExpectTX(toSmem(T5), 65536U);
      Hopper::cpAsyncBulkTensorTileG2S((Hopper::CpAsyncBulkTensorTileG2SIndex<2>{ ptr2, a15, toSmem(T5) }), i6);
      mbarrier::wait(toSmem(T5), i17);
    }
    __syncthreads();
    mbarrier::inval(toSmem(T5));
    #pragma unroll
    for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
      nvfuser_index_t i19;
      i19 = i7 + (2048 * i18);
      bool b20;
      b20 = i16 < (-(16 * i18));
      #pragma unroll
      for(nvfuser_index_t i12 = 0; i12 < 4; ++i12) {
        nvfuser_index_t i21;
        i21 = 32 * i12;
        if ((b20 && (i10 < (-i21)))) {
          T4[i12]
            = T4[i12]
            + T2[(i19 + i21)];
        }
      }
    }
  }
  reduction::iterGroupedGridReduce<false, true, false, false, true, false, false, true, 4>(
    *(float(*)[4])(&T3[0]),
    *(float(*)[4])(&T4[0]),
    [](float &a, float b) { a = a + b; },
    &T6[0],
    &T7[0],
    static_cast<float*>(shared_mem),
    true,
    true,
    float(0.000000000e+00f),
    DefaultBlockDim());
  #pragma unroll
  for(nvfuser_index_t i22 = 0; i22 < 4; ++i22) {
    nvfuser_index_t i23;
    i23 = 32 * i22;
    if ((b11 && (i10 < (-i23)))) {
      T1[(i8 + i23)]
         = T3[i22];
    }
  }
}
non-TMA cuda_kernel
// Codegen generated code
__global__ void nvfuser_reduction_f0_c1_r0_g0(Tensor<float, 2, 2> T0, Tensor<float, 1, 1> T1, Tensor<float, 1, 1> T5, Tensor<int64_t, 1, 1> T6) {
  alignas(128) extern __shared__ char array[];
  void* shared_mem = array;
  NVFUSER_DEFINE_MAGIC_ZERO;
  nvfuser_index_t i0;
  i0 = ceilDiv((ceilDiv((ceilDiv(T0.logical_size[0LL], ((nvfuser_index_t)blockDim.y))), 8)), ((nvfuser_index_t)gridDim.y));
  nvfuser_index_t i1;
  i1 = ((nvfuser_index_t)blockDim.y) * 8;
  nvfuser_index_t i2;
  i2 = i1 * T0.logical_size[1LL];
  nvfuser_index_t i3;
  i3 = 4 * ((nvfuser_index_t)threadIdx.x);
  nvfuser_index_t i4;
  i4 = (4 * ((nvfuser_index_t)blockDim.x)) * ((nvfuser_index_t)blockIdx.x);
  nvfuser_index_t i5;
  i5 = (((T0.logical_size[1LL] * ((nvfuser_index_t)threadIdx.y)) + ((i2 * i0) * ((nvfuser_index_t)blockIdx.y))) + i3) + i4;
  nvfuser_index_t i6;
  i6 = ((nvfuser_index_t)blockDim.y) * T0.logical_size[1LL];
  nvfuser_index_t i7;
  i7 = i3 + i4;
  bool b8;
  b8 = ((3 + i3) + i4) < T0.logical_size[1LL];
  nvfuser_index_t i9;
  i9 = ((((nvfuser_index_t)blockDim.y) * 7) + ((nvfuser_index_t)threadIdx.y)) + ((i1 * ((nvfuser_index_t)blockIdx.y)) * i0);
  nvfuser_index_t i10;
  i10 = ((-T0.logical_size[0LL]) + ((nvfuser_index_t)threadIdx.y)) + ((i1 * i0) * ((nvfuser_index_t)blockIdx.y));
  bool b11;
  b11 = (((nvfuser_index_t)blockIdx.y) == (((nvfuser_index_t)gridDim.y) + -1)) && (((nvfuser_index_t)threadIdx.y) == 0);
  // Allocate global tensor T5
  // Allocate global tensor T6
  Array<float, 4, 4> T4;
  #pragma unroll
  for(nvfuser_index_t i12 = 0; i12 < 4; ++i12) {
    T4[i12] = 0.000000000e+00f;
  }
  NVFUSER_UPDATE_MAGIC_ZERO;
  #pragma unroll 1
  for(nvfuser_index_t i13 = 0; i13 < i0; ++i13) {
    nvfuser_index_t i14;
    i14 = i5 + (i2 * i13);
    nvfuser_index_t i15;
    i15 = i1 * i13;
    nvfuser_index_t i16;
    i16 = i10 + i15;
    if ((b8 && ((i9 + i15) < T0.logical_size[0LL]))) {
      Array<float, 32, 4> T2;
      #pragma unroll
      for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
        T2.set(float(0.000000000e+00f));
      }
      NVFUSER_UPDATE_MAGIC_ZERO;
      #pragma unroll
      for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
        loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T2[(4 * i17)],  &T0[(i14 + (i6 * (i17 + nvfuser_zero)))]);
      }
      NVFUSER_UPDATE_MAGIC_ZERO;
      #pragma unroll
      for(nvfuser_index_t i12 = 0; i12 < 4; ++i12) {
        #pragma unroll
        for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
          T4[i12]
            = T4[i12]
            + T2[(i12 + (4 * i18))];
        }
      }
      NVFUSER_UPDATE_MAGIC_ZERO;
    } else {
      Array<float, 32, 4> T2;
      #pragma unroll
      for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
        T2.set(float(0.000000000e+00f));
      }
      NVFUSER_UPDATE_MAGIC_ZERO;
      #pragma unroll
      for(nvfuser_index_t i17 = 0; i17 < 8; ++i17) {
        nvfuser_index_t i19;
        i19 = i17 + nvfuser_zero;
        if ((b8 && (i16 < (-(((nvfuser_index_t)blockDim.y) * i19))))) {
          loadGlobalToLocal<float, /*vec_size=*/4, /*is_volatile=*/false, CacheOp::Streaming>(&T2[(4 * i17)],  &T0[(i14 + (i6 * i19))]);
        }
      }
      NVFUSER_UPDATE_MAGIC_ZERO;
      #pragma unroll
      for(nvfuser_index_t i12 = 0; i12 < 4; ++i12) {
        #pragma unroll
        for(nvfuser_index_t i18 = 0; i18 < 8; ++i18) {
          T4[i12]
            = T4[i12]
            + T2[(i12 + (4 * i18))];
        }
      }
      NVFUSER_UPDATE_MAGIC_ZERO;
    }
  }
  if (b8) {
    Array<float, 4, 4> T3;
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
      T3[(0) + i] = 0.000000000e+00f;
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    reduction::iterGroupedGridReduce<false, true, false, false, true, false, false, false, 4>(
      *(float(*)[4])(&T3[0]),
      *(float(*)[4])(&T4[0]),
      [](float &a, float b) { a = a + b; },
      &T5[0],
      &T6[0],
      static_cast<float*>(shared_mem),
      true,
      true,
      float(0.000000000e+00f),
      DefaultBlockDim());
    if (b11) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T1[i7], &T3[0]);
    }
  } else {
    Array<float, 4, 4> T3;
    #pragma unroll
    for (int i = 0; i < 4; ++i) {
      T3[(0) + i] = 0.000000000e+00f;
    }
    NVFUSER_UPDATE_MAGIC_ZERO;
    reduction::iterGroupedGridReduce<false, true, false, false, true, false, false, false, 4>(
      *(float(*)[4])(&T3[0]),
      *(float(*)[4])(&T4[0]),
      [](float &a, float b) { a = a + b; },
      &T5[0],
      &T6[0],
      static_cast<float*>(shared_mem),
      true,
      true,
      float(0.000000000e+00f),
      DefaultBlockDim());
    if ((b11 && b8)) {
      loadLocalToGlobal<float, /*vec_size=*/4, /*is_volatile=*/false>( &T1[i7], &T3[0]);
    }
  }
}
TMA fusion_ir
Inputs:
  T0_g_float[iS15{8}, iS16{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iS14{128}, iS17{( ceilDiv(i2, 128) )}, iS18{128}]
Outputs:
  T1_g_float[iblockIdx.x25{( ceilDiv(i2, 128) )}, iS50{4}, ithreadIdx.x51{32}] ca_pos( 3 )

%kernel {
T2_s_float[iblockIdx.y11{8}, iS12{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iB8{128}, iblockIdx.x9{( ceilDiv(i2, 128) )}, iB10{128}] ca_pos( 2 )
   = CpAsyncBulkTensorTile( T0_g_float[iS15{8}, iS16{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iS14{128}, iS17{( ceilDiv(i2, 128) )}, iS18{128}] )
T4_l_float[iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf, rS37{8}rf, ithreadIdx.y38{16}rf, iblockIdx.x39{( ceilDiv(i2, 128) )}, iS41{4}, ithreadIdx.x42{32}] ca_pos( 1 ) produce_pos( 2 )
   = reduction( T2_s_float[iblockIdx.y11{8}, iS12{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iB8{128}, iblockIdx.x9{( ceilDiv(i2, 128) )}, iB10{128}] ca_pos( 2 ), op = add, initial value = float(0), allreduce = false )
T3_l_float[rblockIdx.y43{8}, rthreadIdx.y44{16}, iblockIdx.x46{( ceilDiv(i2, 128) )}, iG48{4}, ithreadIdx.x49{32}] produce_pos( 1 )
   = reduction( T4_l_float[iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf, rS37{8}rf, ithreadIdx.y38{16}rf, iblockIdx.x39{( ceilDiv(i2, 128) )}, iS41{4}, ithreadIdx.x42{32}] ca_pos( 1 ) produce_pos( 2 ), op = add, initial value = float(0), allreduce = false )
T1_g_float[iblockIdx.x25{( ceilDiv(i2, 128) )}, iS50{4}, ithreadIdx.x51{32}] ca_pos( 3 )
   = Set( T3_l_float[rblockIdx.y43{8}, rthreadIdx.y44{16}, iblockIdx.x46{( ceilDiv(i2, 128) )}, iG48{4}, ithreadIdx.x49{32}] produce_pos( 1 ), cache_op=Streaming )

TransformPrinter : 
T0_g_float[iS15{8}, iS16{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iS14{128}, iS17{( ceilDiv(i2, 128) )}, iS18{128}]
  logical domain: (iS0{i0}, iS1{i2})
  contiguity: t t
    Split: iS0{i0} by factor 128 -> iS13{( ceilDiv(i0, 128) )}, iS14{128}
    Outer split: iS13{( ceilDiv(i0, 128) )} by factor 8 -> iS15{8}, iS16{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}
    Split: iS1{i2} by factor 128 -> iS17{( ceilDiv(i2, 128) )}, iS18{128}
  loop domain: (iS15{8}, iS16{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iS14{128}, iS17{( ceilDiv(i2, 128) )}, iS18{128})
T2_s_float[iblockIdx.y11{8}, iS12{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iB8{128}, iblockIdx.x9{( ceilDiv(i2, 128) )}, iB10{128}] ca_pos( 2 )
  logical domain: (iS4{i0}, iS5{i2})
  allocation domain: (iblockIdx.y11{8}, iS12{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iB8{128}, iblockIdx.x9{( ceilDiv(i2, 128) )}, iB10{128})
  contiguity: t t t t t
    Split: iS4{i0} by factor 128 -> iS7{( ceilDiv(i0, 128) )}, iB8{128}
    Outer split: iS7{( ceilDiv(i0, 128) )} by factor 8 -> iblockIdx.y11{8}, iS12{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}
    Split: iS5{i2} by factor 128 -> iblockIdx.x9{( ceilDiv(i2, 128) )}, iB10{128}
  loop domain: (iblockIdx.y11{8}, iS12{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}, iB8{128}, iblockIdx.x9{( ceilDiv(i2, 128) )}, iB10{128})
T4_l_float[iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf, rS37{8}rf, ithreadIdx.y38{16}rf, iblockIdx.x39{( ceilDiv(i2, 128) )}, iS41{4}, ithreadIdx.x42{32}] ca_pos( 1 ) produce_pos( 2 )
  root domain: (rS31{i0}rf, iS32{i2})
    Split: rS31{i0}rf by factor 128 -> rS33{( ceilDiv(i0, 128) )}rf, rS34{128}rf
    Outer split: rS33{( ceilDiv(i0, 128) )}rf by factor 8 -> iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf
    Split: rS34{128}rf by factor 16 -> rS37{8}rf, ithreadIdx.y38{16}rf
  logical domain: (iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf, rS37{8}rf, ithreadIdx.y38{16}rf, iS32{i2})
  contiguity: t n n t t
    Split: iS32{i2} by factor 128 -> iblockIdx.x39{( ceilDiv(i2, 128) )}, iS40{128}
    Split: rS31{i0}rf by factor 128 -> rS33{( ceilDiv(i0, 128) )}rf, rS34{128}rf
    Outer split: rS33{( ceilDiv(i0, 128) )}rf by factor 8 -> iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf
    Split: rS34{128}rf by factor 16 -> rS37{8}rf, ithreadIdx.y38{16}rf
    Split: iS40{128} by factor 32 -> iS41{4}, ithreadIdx.x42{32}
  loop domain: (iblockIdx.y35{8}rf, rS36{( ceilDiv(( ceilDiv(i0, 128) ), 8) )}rf, rS37{8}rf, ithreadIdx.y38{16}rf, iblockIdx.x39{( ceilDiv(i2, 128) )}, iS41{4}, ithreadIdx.x42{32})
T3_l_float[rblockIdx.y43{8}, rthreadIdx.y44{16}, iblockIdx.x46{( ceilDiv(i2, 128) )}, iG48{4}, ithreadIdx.x49{32}] produce_pos( 1 )
  logical domain: (rblockIdx.y43{8}, rthreadIdx.y44{16}, iS45{i2})
  contiguity: n n t
    Split: iS45{i2} by factor 128 -> iblockIdx.x46{( ceilDiv(i2, 128) )}, iS47{128}
    Split: iS47{128} by factor 32 -> iG48{4}, ithreadIdx.x49{32}
  loop domain: (rblockIdx.y43{8}, rthreadIdx.y44{16}, iblockIdx.x46{( ceilDiv(i2, 128) )}, iG48{4}, ithreadIdx.x49{32})
T1_g_float[iblockIdx.x25{( ceilDiv(i2, 128) )}, iS50{4}, ithreadIdx.x51{32}] ca_pos( 3 )
  logical domain: (iS6{i2})
  contiguity: t
    Split: iS6{i2} by factor 128 -> iblockIdx.x25{( ceilDiv(i2, 128) )}, iS26{128}
    Split: iS26{128} by factor 32 -> iS50{4}, ithreadIdx.x51{32}
  loop domain: (iblockIdx.x25{( ceilDiv(i2, 128) )}, iS50{4}, ithreadIdx.x51{32})
} // %kernel
non-TMA fusion_ir
Inputs:
  T0_g_float[iS66{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iS65{blockDim.x}, iS75{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iS67{1}, iS63{4}, iS74{gridDim.y}, iS69{blockDim.y}, iS73{1}, iS71{8}]
Outputs:
  T1_g_float[iblockIdx.x80{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x79{blockDim.x}, iUS81{1}, iV77{4}] ca_pos( 3 ) produce_pos( 3 )

%kernel {
T2_l_float[iblockIdx.x52{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x51{blockDim.x}, iS61{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iUS53{1}, iV49{4}, iblockIdx.y60{gridDim.y}, ithreadIdx.y55{blockDim.y}, iUS59{1}, iUR57{8}] ca_pos( 4 )
   = Set( T0_g_float[iS66{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iS65{blockDim.x}, iS75{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iS67{1}, iS63{4}, iS74{gridDim.y}, iS69{blockDim.y}, iS73{1}, iS71{8}], cache_op=Streaming )
T4_l_float[iblockIdx.x29{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x28{blockDim.x}, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf, iUS30{1}, iS26{4}, iblockIdx.y37{gridDim.y}rf, ithreadIdx.y32{blockDim.y}rf, rUS36{1}rf, rS34{8}rf] ca_pos( 2 ) produce_pos( 4 )
   = reduction( T2_l_float[iblockIdx.x52{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x51{blockDim.x}, iS61{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iUS53{1}, iV49{4}, iblockIdx.y60{gridDim.y}, ithreadIdx.y55{blockDim.y}, iUS59{1}, iUR57{8}] ca_pos( 4 ), op = add, initial value = float(0), allreduce = false )
T3_l_float[iblockIdx.x46{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x45{blockDim.x}, iUS47{1}, iG43{4}, rblockIdx.y39{gridDim.y}, rthreadIdx.y40{blockDim.y}] ca_pos( 3 ) produce_pos( 2 )
   = reduction( T4_l_float[iblockIdx.x29{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x28{blockDim.x}, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf, iUS30{1}, iS26{4}, iblockIdx.y37{gridDim.y}rf, ithreadIdx.y32{blockDim.y}rf, rUS36{1}rf, rS34{8}rf] ca_pos( 2 ) produce_pos( 4 ), op = add, initial value = float(0), allreduce = false )
T1_g_float[iblockIdx.x80{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x79{blockDim.x}, iUS81{1}, iV77{4}] ca_pos( 3 ) produce_pos( 3 )
   = Set( T3_l_float[iblockIdx.x46{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x45{blockDim.x}, iUS47{1}, iG43{4}, rblockIdx.y39{gridDim.y}, rthreadIdx.y40{blockDim.y}] ca_pos( 3 ) produce_pos( 2 ), cache_op=Streaming )

TransformPrinter : 
T0_g_float[iS66{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iS65{blockDim.x}, iS75{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iS67{1}, iS63{4}, iS74{gridDim.y}, iS69{blockDim.y}, iS73{1}, iS71{8}]
  logical domain: (iS0{i0}, iS1{i2})
  contiguity: t t
    Split: iS1{i2} by factor 4 -> iS62{( ceilDiv(i2, 4) )}, iS63{4}
    Split: iS0{i0} by factor blockDim.y -> iS68{( ceilDiv(i0, blockDim.y) )}, iS69{blockDim.y}
    Split: iS62{( ceilDiv(i2, 4) )} by factor blockDim.x -> iS64{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iS65{blockDim.x}
    Split: iS68{( ceilDiv(i0, blockDim.y) )} by factor 8 -> iS70{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}, iS71{8}
    Split: iS64{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )} by factor 1 -> iS66{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iS67{1}
    Split: iS70{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )} by factor 1 -> iS72{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}, iS73{1}
    Outer split: iS72{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )} by factor gridDim.y -> iS74{gridDim.y}, iS75{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}
  loop domain: (iS66{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iS65{blockDim.x}, iS75{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iS67{1}, iS63{4}, iS74{gridDim.y}, iS69{blockDim.y}, iS73{1}, iS71{8})
T2_l_float[iblockIdx.x52{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x51{blockDim.x}, iS61{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iUS53{1}, iV49{4}, iblockIdx.y60{gridDim.y}, ithreadIdx.y55{blockDim.y}, iUS59{1}, iUR57{8}] ca_pos( 4 )
  logical domain: (iS6{i0}, iS7{i2})
  contiguity: t t
    Split: iS7{i2} by factor 4 -> iS48{( ceilDiv(i2, 4) )}, iV49{4}
    Split: iS6{i0} by factor blockDim.y -> iS54{( ceilDiv(i0, blockDim.y) )}, ithreadIdx.y55{blockDim.y}
    Split: iS48{( ceilDiv(i2, 4) )} by factor blockDim.x -> iS50{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x51{blockDim.x}
    Split: iS54{( ceilDiv(i0, blockDim.y) )} by factor 8 -> iS56{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}, iUR57{8}
    Split: iS50{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )} by factor 1 -> iblockIdx.x52{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iUS53{1}
    Split: iS56{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )} by factor 1 -> iS58{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}, iUS59{1}
    Outer split: iS58{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )} by factor gridDim.y -> iblockIdx.y60{gridDim.y}, iS61{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}
  loop domain: (iblockIdx.x52{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x51{blockDim.x}, iS61{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}, iUS53{1}, iV49{4}, iblockIdx.y60{gridDim.y}, ithreadIdx.y55{blockDim.y}, iUS59{1}, iUR57{8})
T4_l_float[iblockIdx.x29{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x28{blockDim.x}, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf, iUS30{1}, iS26{4}, iblockIdx.y37{gridDim.y}rf, ithreadIdx.y32{blockDim.y}rf, rUS36{1}rf, rS34{8}rf] ca_pos( 2 ) produce_pos( 4 )
  root domain: (rS23{i0}rf, iS24{i2})
    Split: rS23{i0}rf by factor blockDim.y -> rS31{( ceilDiv(i0, blockDim.y) )}rf, ithreadIdx.y32{blockDim.y}rf
    Split: rS31{( ceilDiv(i0, blockDim.y) )}rf by factor 8 -> rS33{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf, rS34{8}rf
    Split: rS33{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf by factor 1 -> rS35{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf, rUS36{1}rf
    Outer split: rS35{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf by factor gridDim.y -> iblockIdx.y37{gridDim.y}rf, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf
  logical domain: (iblockIdx.y37{gridDim.y}rf, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf, rUS36{1}rf, rS34{8}rf, ithreadIdx.y32{blockDim.y}rf, iS24{i2})
  contiguity: t n n n t t
    Split: iS24{i2} by factor 4 -> iS25{( ceilDiv(i2, 4) )}, iS26{4}
    Split: rS23{i0}rf by factor blockDim.y -> rS31{( ceilDiv(i0, blockDim.y) )}rf, ithreadIdx.y32{blockDim.y}rf
    Split: iS25{( ceilDiv(i2, 4) )} by factor blockDim.x -> iS27{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x28{blockDim.x}
    Split: rS31{( ceilDiv(i0, blockDim.y) )}rf by factor 8 -> rS33{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf, rS34{8}rf
    Split: iS27{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )} by factor 1 -> iblockIdx.x29{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iUS30{1}
    Split: rS33{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf by factor 1 -> rS35{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf, rUS36{1}rf
    Outer split: rS35{( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) )}rf by factor gridDim.y -> iblockIdx.y37{gridDim.y}rf, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf
  loop domain: (iblockIdx.x29{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x28{blockDim.x}, rS38{( ceilDiv(( ceilDiv(( ceilDiv(i0, blockDim.y) ), 8) ), gridDim.y) )}rf, iUS30{1}, iS26{4}, iblockIdx.y37{gridDim.y}rf, ithreadIdx.y32{blockDim.y}rf, rUS36{1}rf, rS34{8}rf)
T3_l_float[iblockIdx.x46{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x45{blockDim.x}, iUS47{1}, iG43{4}, rblockIdx.y39{gridDim.y}, rthreadIdx.y40{blockDim.y}] ca_pos( 3 ) produce_pos( 2 )
  logical domain: (rblockIdx.y39{gridDim.y}, rthreadIdx.y40{blockDim.y}, iS41{i2})
  contiguity: n n t
    Split: iS41{i2} by factor 4 -> iS42{( ceilDiv(i2, 4) )}, iG43{4}
    Split: iS42{( ceilDiv(i2, 4) )} by factor blockDim.x -> iS44{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x45{blockDim.x}
    Split: iS44{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )} by factor 1 -> iblockIdx.x46{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iUS47{1}
  loop domain: (iblockIdx.x46{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x45{blockDim.x}, iUS47{1}, iG43{4}, rblockIdx.y39{gridDim.y}, rthreadIdx.y40{blockDim.y})
T1_g_float[iblockIdx.x80{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x79{blockDim.x}, iUS81{1}, iV77{4}] ca_pos( 3 ) produce_pos( 3 )
  logical domain: (iS8{i2})
  contiguity: t
    Split: iS8{i2} by factor 4 -> iS76{( ceilDiv(i2, 4) )}, iV77{4}
    Split: iS76{( ceilDiv(i2, 4) )} by factor blockDim.x -> iS78{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x79{blockDim.x}
    Split: iS78{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )} by factor 1 -> iblockIdx.x80{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, iUS81{1}
  loop domain: (iblockIdx.x80{( ceilDiv(( ceilDiv(i2, 4) ), blockDim.x) )}, ithreadIdx.x79{blockDim.x}, iUS81{1}, iV77{4})
} // %kernel

@tbqh tbqh requested a review from liqiangxl February 6, 2026 10:06

greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR adds a comprehensive manually-scheduled test for outer reduction operations using 2D TMA (Tensor Memory Accelerator), complementing the existing inner reduction TMA tests.

Key changes:

  • Implemented TmaOuterReductionManualTest test class with parameterized testing for various outer_size and iter_size combinations (256 to 65536, stepping by factors of 4)
  • Added 10-phase manual scheduling strategy for outer reductions that reduces along axis 0 (the outer dimension)
  • Implemented proper TMA alignment checks (16-byte requirement) with appropriate test skipping logic
  • Used iter-grouped grid reduction (iterGroupedGridReduce) for better efficiency and reduced bank conflicts compared to separate grid reduce calls
  • Included extensive inline comments documenting each scheduling phase for maintainability
  • Test validates correctness by comparing manually scheduled kernel outputs against an unscheduled reference fusion

The test follows established patterns from TmaInnerReductionManualTest and properly handles edge cases through GTEST_SKIP when dimensions are too small for meaningful 2D TMA testing.

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • The implementation follows well-established patterns from the existing TmaInnerReductionManualTest, includes comprehensive documentation via inline comments, handles edge cases properly with GTEST_SKIP, and only adds new test code without modifying existing functionality. The test validates correctness against an unscheduled reference fusion.
  • No files require special attention

Important Files Changed

Filename: tests/cpp/test_reduction.cpp
Overview: Added comprehensive manually-scheduled test for outer reduction with 2D TMA, following established patterns from similar inner reduction tests

Sequence Diagram

sequenceDiagram
    participant Test as Test Framework
    participant Fusion as Fusion Builder
    participant Schedule as Scheduler
    participant TMA as TMA Cache
    participant Executor as KernelExecutor
    participant Validation as Validator

    Test->>Fusion: Create fusion with 2D tensor input [R, I]
    Fusion->>Fusion: Add sum reduction on axis 0 (outer reduction)
    Fusion->>Fusion: Save fusion_copy for validation
    
    Test->>Schedule: Create TMA cache in shared memory
    Schedule->>TMA: cacheAfter with CpAsyncBulkTensorTile
    TMA->>TMA: Set memory type to Shared
    
    Test->>Schedule: Apply TMA-level tiling
    Schedule->>Schedule: Split reduction dimension by tma_tile_r
    Schedule->>Schedule: Split iteration dimension by tma_tile_i
    Schedule->>Schedule: Split for grid parallelization (grdim)
    
    Test->>Schedule: Propagate TMA tiling to all tensors
    Schedule->>Schedule: MaxLogicalDomainInfoSpanningTree traverse
    
    Test->>Schedule: Parallelize TMA tensor
    Schedule->>Schedule: Set BIDy, Serial, Bulk, BIDx parallelization
    Schedule->>Schedule: Set allocation domain for shared memory
    
    Test->>Schedule: Sub-split TMA tiles into thread dimensions
    Schedule->>Schedule: Split tma_tile_i by bdimx (32)
    Schedule->>Schedule: Split tma_tile_r by bdimy (16)
    
    Test->>Schedule: Parallelize reduction tensor
    Schedule->>Schedule: Apply rFactor for grid reduction
    Schedule->>Schedule: Propagate thread-level splits to non-TMA TVs
    
    Test->>Schedule: Set up iter-grouped reduction
    Schedule->>Schedule: propagateParallelization with use_iter_grouped_reduction=true
    Schedule->>Schedule: Apply inlineMost optimization
    
    Test->>Executor: Compile fusion with input tensor
    Executor->>Executor: Generate CUDA kernel code
    Test->>Executor: Run kernel with test input
    Executor-->>Test: Return computed outputs
    
    Test->>Validation: testValidate with fusion_copy
    Validation->>Validation: Compare scheduled vs unscheduled results
    Validation-->>Test: Validation success

greptile-apps bot (Contributor) left a comment

1 file reviewed, no comments

github-actions bot commented Feb 6, 2026

Description

  • Add comprehensive manual scheduling test for 2D TMA outer reductions

  • Test validates TMA alignment requirements (16-byte alignment on both dimensions)

  • Implements manual fusion scheduling with TMA-specific optimizations including grid reduction

  • Parameterized test suite covering tensor sizes from 256 to 65536 elements

Changes walkthrough

Relevant files
Tests: tests/cpp/test_reduction.cpp - Add TMA outer reduction manual test

  • Add TmaOuterReductionManualTest class with parameterized tests for 2D
    TMA outer reductions
  • Implement manual fusion scheduling with TMA tensor loading using
    CpAsyncBulkTensorTile
  • Add grid reduction with iter-grouped reduction for efficient outer
    reduction computation
  • Include TMA alignment validation and skip conditions for proper test
    execution
  • Parameterize test with tensor sizes from 256 to 65536 elements on both
    dimensions
  • +204/-0 

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Complex Manual Scheduling

    The test implements a sophisticated 10-phase manual scheduling strategy for TMA outer reductions. While the code appears structurally sound, the complexity makes it prone to subtle bugs. Key areas to verify: (1) TMA tile dimension calculations ensure proper alignment, (2) The rFactor axes selection logic correctly identifies non-thread-parallelized reduction axes, (3) The iter-grouped reduction propagation works correctly with the TMA tensor hierarchy.

    // ========== Phase 1: Create TMA cache in shared memory ==========
    auto tv0smem = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
    tv0smem->setMemoryType(MemoryType::Shared);
    
    // Cache before the output for the grid reduction
    auto redu_tv = tv1->cacheBefore();
    
    // ========== Phase 2: Schedule TMA tensor with TMA-level tiling ==========
    // [R, I] -> [R/tma_tile_r, tma_tile_r, I]
    tv0smem->split(0, tma_tile_r);
    
    // -> [R/tma_tile_r, tma_tile_r, I/tma_tile_i, tma_tile_i]
    tv0smem->split(2, tma_tile_i);
    
    // Split outer reduction for grid parallelization
    // -> [grdim, R', tma_tile_r, I/tma_tile_i, tma_tile_i]
    //       0      1    2           3              4
    tv0smem->split(0, grdim, false);
    
    // ========== Phase 3: Propagate TMA tiling to all tensors ==========
    TransformPropagator propagator(tv0smem);
    MaxLogicalDomainInfoSpanningTree(tv0smem).traverse(&propagator);
    
    // ========== Phase 4: Parallelize and finalize TMA tensor ==========
    tv0smem->axis(0)->parallelize(ParallelType::BIDy);
    tv0smem->axis(1)->parallelize(ParallelType::Serial);
    tv0smem->axis(2)->parallelize(ParallelType::Bulk); // reduction tile
    tv0smem->axis(3)->parallelize(ParallelType::BIDx);
    tv0smem->axis(4)->parallelize(ParallelType::Bulk); // iteration tile
    
    // Set allocation domain for proper shared memory layout
    tv0smem->setAllocationDomain(tv0smem->getLoopDomain(), true);
    
    // ========== Phase 5: Sub-split TMA tiles into thread dims ==========
    // Split tma_tile_i into [iter_unroll, bdimx]
    redu_tv->split(4, bdimx);
    
    // Split tma_tile_r into [redu_unroll, bdimy]
    redu_tv->split(2, bdimy);
    // Now: [grdim, R', redu_unroll, bdimy, I/tma_tile_i, iter_unroll, bdimx]
    //       0      1   2            3      4              5            6
    
    // ========== Phase 6: Parallelize reduction tensor ==========
    redu_tv->axis(0)->parallelize(ParallelType::BIDy);
    redu_tv->axis(1)->parallelize(ParallelType::Serial);
    redu_tv->axis(2)->parallelize(ParallelType::Unroll); // redu_unroll
    redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
    redu_tv->axis(4)->parallelize(ParallelType::BIDx);
    // Use Vectorize so it gets converted to Group for iterGroupedGridReduce
    redu_tv->axis(5)->parallelize(ParallelType::Vectorize);
    redu_tv->axis(6)->parallelize(ParallelType::TIDx); // bdimx
    
    // ========== Phase 7: rFactor for grid reduction ==========
    // The reduction axes that are not thread-parallelized need rFactor
    std::vector<int64_t> rfactor_axes;
    for (int64_t i = 0; i < redu_tv->nDims(); i++) {
      if (redu_tv->axis(i)->isReduction() && !redu_tv->axis(i)->isThread()) {
        rfactor_axes.push_back(i);
      }
    }
    
    TensorView* ref_tv = redu_tv;
    if (!rfactor_axes.empty()) {
      ref_tv = redu_tv->rFactor(rfactor_axes);
    }
    
    // ========== Phase 8: Propagate thread-level splits to non-TMA TVs ==========
    std::vector<TensorView*> non_tma_tvs =
        ir_utils::allTvsExcept(&fusion, {tv0smem});
    TransformPropagator non_tma_propagator(ref_tv);
    SetSelector non_tma_selector({non_tma_tvs.begin(), non_tma_tvs.end()});
    MaxLogicalDomainInfoSpanningTree(ref_tv, &non_tma_selector)
        .traverse(&non_tma_propagator);
    
    // ========== Phase 9: Parallelize with iter-grouped reduction ==========
    // For outer reduction, use iterGroupedGridReduce which is more efficient
    // and has fewer bank conflicts than separate gridReduce calls
    const bool use_iter_grouped_reduction = true; // outer reduction + cross_block
    std::vector<TensorView*> reduction_tvs = {redu_tv};
    
    if (ref_tv != redu_tv) {
      reduction_scheduler_utils::propagateRFactor(ref_tv, redu_tv, reduction_tvs);
      non_tma_tvs = ir_utils::allTvsExcept(&fusion, {tv0smem});
    }
    
    // Use propagateParallelization to set up grouped reduction
    reduction_scheduler_utils::propagateParallelization(
        redu_tv,
        ref_tv,
        /*is_unroll_or_vectorization=*/true,
        use_iter_grouped_reduction,
        reduction_tvs,
        /*unroll_vectorizable_cached_tvs=*/{},
        /*selected_tvs=*/non_tma_tvs);
    TMA Alignment Validation

    The test includes proper 16-byte alignment checks for TMA operations, but verify that the alignment validation logic correctly handles all edge cases and that the skipped test cases don't represent important coverage gaps.

    // TMA requires 16-byte alignment on both dimensions
    if ((outer_size * dtype_bytes) % 16 != 0) {
      GTEST_SKIP() << "Outer dimension bytes not divisible by 16, can't use TMA";
      return;
    }
    if ((iter_size * dtype_bytes) % 16 != 0) {
      GTEST_SKIP() << "Iter dimension bytes not divisible by 16, can't use TMA";
      return;
    }
    Performance Impact

    This is a manual scheduling test that likely targets specific performance optimizations. Verify that the test demonstrates the expected performance benefits and doesn't introduce regressions in the broader test suite.

    // Test 2D TMA for outer reductions
    using TmaOuterReductionManualTestParams =
        std::tuple<int64_t, int64_t>; // <outer_size, iter_size>
    
    class TmaOuterReductionManualTest
        : public NVFuserFixtureParamTest<TmaOuterReductionManualTestParams> {
     protected:
      void SetUp() override {
        NVFuserFixtureParamTest<TmaOuterReductionManualTestParams>::SetUp();
        NVFUSER_TEST_CUDA_ARCH_GUARD(9, 0);
        enable_options_guard_ = std::make_unique<EnableOptionsGuard>();
        EnableOptionsGuard::getCurOptions().set(EnableOption::TmaReduction);
      }
    
     private:
      std::unique_ptr<EnableOptionsGuard> enable_options_guard_;
    };
    TEST_P(TmaOuterReductionManualTest, Basic) {
      auto dtype = DataType::Float;
      int64_t dtype_bytes = dataTypeSizeByte(dtype);
      auto [outer_size, iter_size] = GetParam();
    
      // TMA requires 16-byte alignment on both dimensions
      if ((outer_size * dtype_bytes) % 16 != 0) {
        GTEST_SKIP() << "Outer dimension bytes not divisible by 16, can't use TMA";
        return;
      }
      if ((iter_size * dtype_bytes) % 16 != 0) {
        GTEST_SKIP() << "Iter dimension bytes not divisible by 16, can't use TMA";
        return;
      }
    
      std::vector<int64_t> shape = {outer_size, iter_size};
    
      auto fusion_ptr = std::make_unique<Fusion>();
      FusionGuard fg(fusion_ptr.get());
      Fusion& fusion = *fusion_ptr;
    
      // Input: [R, I] where R is reduction dim (axis 0), I is iteration dim (axis
      // 1)
      auto tv0 = makeContigTensor(2);
      fusion.addInput(tv0);
      auto tv1 = sum(tv0, {0}); // reduce along axis 0 (outer reduction)
      fusion.addOutput(tv1);
      auto fusion_copy = fusion;
    
      auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0);
      at::Tensor t0 = at::randn(shape, options);
    
      const int64_t bdimx = 32;
      const int64_t bdimy = 16;
      const int64_t iter_unroll_factor = 4;
      const int64_t redu_unroll_factor = 8;
    
      int64_t grdim = std::max<int64_t>(
          1, std::min<int64_t>(8, scheduler_utils::lastPow2(outer_size / 256)));
    
      // TMA tile dimensions = thread dims * unroll factors
      const int64_t tma_tile_i = bdimx * iter_unroll_factor;
      const int64_t tma_tile_r = bdimy * redu_unroll_factor;
    
      // Skip if outer_size is too small for meaningful 2D TMA test
      if (outer_size < tma_tile_r) {
        GTEST_SKIP() << "outer_size " << outer_size
                     << " is smaller than tma_tile_r " << tma_tile_r;
        return;
      }
    
      // Skip if iter_size is too small for meaningful 2D TMA test
      if (iter_size < tma_tile_i) {
        GTEST_SKIP() << "iter_size " << iter_size << " is smaller than tma_tile_i "
                     << tma_tile_i;
        return;
      }
    
      // ========== Phase 1: Create TMA cache in shared memory ==========
      auto tv0smem = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile);
      tv0smem->setMemoryType(MemoryType::Shared);
    
      // Cache before the output for the grid reduction
      auto redu_tv = tv1->cacheBefore();
    
      // ========== Phase 2: Schedule TMA tensor with TMA-level tiling ==========
      // [R, I] -> [R/tma_tile_r, tma_tile_r, I]
      tv0smem->split(0, tma_tile_r);
    
      // -> [R/tma_tile_r, tma_tile_r, I/tma_tile_i, tma_tile_i]
      tv0smem->split(2, tma_tile_i);
    
      // Split outer reduction for grid parallelization
      // -> [grdim, R', tma_tile_r, I/tma_tile_i, tma_tile_i]
      //       0      1    2           3              4
      tv0smem->split(0, grdim, false);
    
      // ========== Phase 3: Propagate TMA tiling to all tensors ==========
      TransformPropagator propagator(tv0smem);
      MaxLogicalDomainInfoSpanningTree(tv0smem).traverse(&propagator);
    
      // ========== Phase 4: Parallelize and finalize TMA tensor ==========
      tv0smem->axis(0)->parallelize(ParallelType::BIDy);
      tv0smem->axis(1)->parallelize(ParallelType::Serial);
      tv0smem->axis(2)->parallelize(ParallelType::Bulk); // reduction tile
      tv0smem->axis(3)->parallelize(ParallelType::BIDx);
      tv0smem->axis(4)->parallelize(ParallelType::Bulk); // iteration tile
    
      // Set allocation domain for proper shared memory layout
      tv0smem->setAllocationDomain(tv0smem->getLoopDomain(), true);
    
      // ========== Phase 5: Sub-split TMA tiles into thread dims ==========
      // Split tma_tile_i into [iter_unroll, bdimx]
      redu_tv->split(4, bdimx);
    
      // Split tma_tile_r into [redu_unroll, bdimy]
      redu_tv->split(2, bdimy);
      // Now: [grdim, R', redu_unroll, bdimy, I/tma_tile_i, iter_unroll, bdimx]
      //       0      1   2            3      4              5            6
    
      // ========== Phase 6: Parallelize reduction tensor ==========
      redu_tv->axis(0)->parallelize(ParallelType::BIDy);
      redu_tv->axis(1)->parallelize(ParallelType::Serial);
      redu_tv->axis(2)->parallelize(ParallelType::Unroll); // redu_unroll
      redu_tv->axis(3)->parallelize(ParallelType::TIDy); // bdimy
      redu_tv->axis(4)->parallelize(ParallelType::BIDx);
      // Use Vectorize so it gets converted to Group for iterGroupedGridReduce
      redu_tv->axis(5)->parallelize(ParallelType::Vectorize);
      redu_tv->axis(6)->parallelize(ParallelType::TIDx); // bdimx
    
      // ========== Phase 7: rFactor for grid reduction ==========
      // The reduction axes that are not thread-parallelized need rFactor
      std::vector<int64_t> rfactor_axes;
      for (int64_t i = 0; i < redu_tv->nDims(); i++) {
        if (redu_tv->axis(i)->isReduction() && !redu_tv->axis(i)->isThread()) {
          rfactor_axes.push_back(i);
        }
      }
    
      TensorView* ref_tv = redu_tv;
      if (!rfactor_axes.empty()) {
        ref_tv = redu_tv->rFactor(rfactor_axes);
      }
    
      // ========== Phase 8: Propagate thread-level splits to non-TMA TVs ==========
      std::vector<TensorView*> non_tma_tvs =
          ir_utils::allTvsExcept(&fusion, {tv0smem});
      TransformPropagator non_tma_propagator(ref_tv);
      SetSelector non_tma_selector({non_tma_tvs.begin(), non_tma_tvs.end()});
      MaxLogicalDomainInfoSpanningTree(ref_tv, &non_tma_selector)
          .traverse(&non_tma_propagator);
    
      // ========== Phase 9: Parallelize with iter-grouped reduction ==========
      // For outer reduction, use iterGroupedGridReduce which is more efficient
      // and has fewer bank conflicts than separate gridReduce calls
      const bool use_iter_grouped_reduction = true; // outer reduction + cross_block
      std::vector<TensorView*> reduction_tvs = {redu_tv};
    
      if (ref_tv != redu_tv) {
        reduction_scheduler_utils::propagateRFactor(ref_tv, redu_tv, reduction_tvs);
        non_tma_tvs = ir_utils::allTvsExcept(&fusion, {tv0smem});
      }
    
      // Use propagateParallelization to set up grouped reduction
      reduction_scheduler_utils::propagateParallelization(
          redu_tv,
          ref_tv,
          /*is_unroll_or_vectorization=*/true,
          use_iter_grouped_reduction,
          reduction_tvs,
          /*unroll_vectorizable_cached_tvs=*/{},
          /*selected_tvs=*/non_tma_tvs);
    
      // ========== Phase 10: Inline ==========
      inlineMost();
    
      KernelExecutor ke;
      ke.compile(&fusion, {t0});
      auto cg_outputs = ke.run({t0});
    
      testValidate(&fusion_copy, cg_outputs, {t0}, __LINE__, __FILE__, "");
    }
    
    INSTANTIATE_TEST_SUITE_P(
        ,
        TmaOuterReductionManualTest,
        testing::Combine(
            testing::ValuesIn([] { // outer_size
              std::vector<int64_t> vals;
              for (int64_t v = 256; v <= 65536; v *= 4) {
                vals.push_back(v);
              }
              return vals;
            }()),
            testing::ValuesIn([] { // iter_size
              std::vector<int64_t> vals;
              for (int64_t v = 256; v <= 65536; v *= 4) {
                vals.push_back(v);
              }
              return vals;
            }())),
        ([](const testing::TestParamInfo<TmaOuterReductionManualTestParams>& info) {
          auto [outer_size, iter_size] = info.param;
          return "outer_" + std::to_string(outer_size) + "_iter_" +
              std::to_string(iter_size);
        }));

    @liqiangxl (Collaborator) commented:

    Is the generated kernel the same as what we discussed offline? You can attach the generated code & fusion_ir to this PR.

    @tbqh (Collaborator, Author) commented Feb 6, 2026

    @liqiangxl I added cuda_kernel and fusion_ir dumps to the PR description.

    The schedule is very similar. For 16384x16384, both use grid reduction via reduction::iterGroupedGridReduce. The non-TMA kernel also uses unswitching; I am not sure that is needed for 2D TMA, which I thought loads zero for out-of-bounds reads (and divisibility is not a concern here the way it was for 1D TMA with inner reduction).

    To keep things simple, I don't do a block-reduce vs grid-reduce comparison; there are just a few options for the BIDy count.
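
    As a worked check (mirroring the grdim expression from the test code quoted above, for the outer_16384_iter_16384 dump), the BIDy count works out to 8:

    // Sketch: evaluate the test's grdim heuristic for outer_size = 16384.
    int64_t outer_size = 16384;
    int64_t grdim = std::max<int64_t>(
        1, std::min<int64_t>(8, scheduler_utils::lastPow2(outer_size / 256)));
    // 16384 / 256 = 64, lastPow2(64) = 64, min(8, 64) = 8  ->  grdim == 8,
    // which matches iblockIdx.y{8} in the TMA fusion_ir above.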


    // ========== Phase 5: Sub-split TMA tiles into thread dims ==========
    // Split tma_tile_i into [iter_unroll, bdimx]
    redu_tv->split(4, bdimx);

    Reviewer comment (Collaborator):

    The iter domain is the innermost domain; we should split it by the vectorization factor first, then bdimx. This allows vectorized writes of the reduction result to gmem.
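
    A minimal sketch of this suggestion (hypothetical, reusing the iter_unroll_factor and bdimx names from the test): split off the vectorization factor first so it lands innermost, then split by bdimx.

    // Instead of redu_tv->split(4, bdimx) alone:
    // [..., tma_tile_i] -> [..., tma_tile_i / iter_unroll, iter_unroll]
    redu_tv->split(4, iter_unroll_factor);
    // -> [..., (tma_tile_i / iter_unroll) / bdimx, bdimx, iter_unroll]
    redu_tv->split(4, bdimx);
    // With iter_unroll innermost, the reduction result can be written to gmem
    // with a vectorized store (axis(-1) -> Vectorize/Group, axis(-2) -> TIDx).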

    tv0smem->axis(4)->parallelize(ParallelType::Bulk); // iteration tile

    // Set allocation domain for proper shared memory layout
    tv0smem->setAllocationDomain(tv0smem->getLoopDomain(), true);

    Reviewer comment (Collaborator):

    Do we need this allocation domain set?

    reduction_tvs,
    /*unroll_vectorizable_cached_tvs=*/{},
    /*selected_tvs=*/non_tma_tvs);


    Reviewer comment (Collaborator):

    Need to vectorize the output tensor.
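
    A hypothetical sketch of this (assuming the split-order change suggested above, so the size-4 factor ends up as the innermost axis of the output): vectorize the global-memory write of the fusion output instead of leaving that axis serial.

    // tv1 is the fusion output; after transform propagation its innermost axis
    // would be the iter_unroll factor, which can then be vectorized for the
    // gmem store:
    tv1->axis(-1)->parallelize(ParallelType::Vectorize);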

    @liqiangxl (Collaborator) commented:

    > @liqiangxl I added cuda_kernel and fusion_ir dumps to the PR description.
    >
    > The schedule is very similar. For 16384x16384, both use grid reduction via reduction::iterGroupedGridReduce. The non-TMA kernel also uses unswitching; I am not sure that is needed for 2D TMA, which I thought loads zero for out-of-bounds reads (and divisibility is not a concern here the way it was for 1D TMA with inner reduction).
    >
    > To keep things simple, I don't do a block-reduce vs grid-reduce comparison; there are just a few options for the BIDy count.

    Thanks. Can you add the block reduction kernel and IR? Unswitch is a performance optimization; we can check that later.
