Skip to content

remove input cached smem bank conflicts in transpose scheduler#5930

Draft
liqiangxl wants to merge 6 commits intollu/transpose_tile_sizefrom
llu/transpose_bank_conflict
Draft

remove input cached smem bank conflicts in transpose scheduler#5930
liqiangxl wants to merge 6 commits intollu/transpose_tile_sizefrom
llu/transpose_bank_conflict

Conversation

@liqiangxl
Copy link
Collaborator

No description provided.

@liqiangxl
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented Feb 6, 2026

Review updated until commit a9de12f

Description

  • Filter reduction domains in transpose domain mapping to improve scheduler accuracy

  • Add shared memory swizzle optimization to reduce bank conflicts in transpose scheduler

  • Move tile size doubling logic and add debug output for memory bandwidth analysis

  • Remove invalid test cases that were failing due to scheduler limitations

  • Update function signatures to use raw pointers for better performance

Changes walkthrough

Relevant files
Enhancement
domain_map.cpp
Filter reduction domains in transpose domain mapping         

csrc/scheduler/tools/domain_map.cpp

  • Filter out reduction domains before comparing allocation domains
  • Apply filtering to both reference loops in hasAtLeastTwoValidGroups
  • Update broadcast detection to use filtered domains
  • +9/-3     
    transpose.cpp
    Add shared memory swizzle optimization and improve bank conflict
    handling

    csrc/scheduler/transpose.cpp

  • Add shared memory swizzle scheduling for cached input tensors to
    reduce bank conflicts
  • Move tile size doubling logic and add debug output for memory
    bandwidth analysis
  • Update hasSmallTransposeDimensions to use raw pointer parameter
  • Add conditions to disable swizzle for non-square tiles and cached
    outputs
  • Update exclusion sets to include swizzled tensors in transformation
    propagation
  • +90/-33 
    Tests
    test_gpu3.cpp
    Remove invalid transpose test case                                             

    tests/cpp/test_gpu3.cpp

    +0/-22   
    test_transpose.cpp
    Fix scheduler type expectation in reduction test                 

    tests/cpp/test_transpose.cpp

  • Update expected scheduler type from Transpose to PointWise in
    reduction test
  • +1/-1     

    PR Reviewer Guide

    Here are some key observations to aid the review process:

    🧪 PR contains tests
    ⚡ Recommended focus areas for review
    Debug Output

    Multiple std::cout statements were added for debugging (lines 747-752). These should be removed or replaced with proper logging mechanisms before merging.

    std::cout << "total_input_bits_per_elem: " << total_input_bits_per_elem
              << std::endl;
    std::cout << "num_elems_per_tile: " << num_elems_per_tile << std::endl;
    std::cout << "max_blocks_per_sm: " << max_blocks_per_sm << std::endl;
    std::cout << "bits_in_flight_per_sm: " << bits_in_flight_per_sm << std::endl;
    std::cout << "required_bits_per_sm: " << required_bits_per_sm << std::endl;
    Test Coverage

    Two test cases were removed without clear explanation. The removed tests (FusionScheduleTransposeRepro1_CUDA and BroadcastingRNGSmemNonSquareTile) may have been testing important scenarios that should either be updated to work with the new swizzle logic or replaced with equivalent tests.

    // clang-format off
    Swizzle Logic Validation

    The new shared memory swizzle logic (lines 1352-1373) needs validation to ensure it correctly handles bank conflicts across different tile sizes and vectorization factors. The conditions for disabling swizzle (lines 976-979) should be thoroughly tested.

    if (use_smem_swizzle) {
      for (auto tv : smem_cached_input_tvs) {
        std::cout << "scheduling smem_cached_tv: " << tv->toString() << std::endl;
        int64_t pos = tv->nDims() - 2;
        bool is_group2 = group2_and_cached_inputs.count(tv) > 0;
        int64_t tile2_factor =
            is_group2 ? tparams->vectorize_factor2 : tparams->vectorize_factor1;
        int64_t tile1_factor =
            tparams->tile_size1 * tile2_factor / tparams->tile_size2;
        // [BIDx, UnSwitch, tile1, tile2]
        tv->split(pos + 1, tile2_factor);
        tv->split(pos, tile1_factor);
        tv->swizzle(SwizzleType::XOR, pos, pos + 2);
        tv->merge(pos);
        tv->merge(pos);
        tv->split(pos, tparams->getThreadsPerBlock());
        tv->axis(pos)->parallelize(ParallelType::Unroll);
        tv->axis(pos + 1)->parallelize(ParallelType::TIDx);
        tv->axis(pos + 2)->parallelize(ParallelType::Vectorize);
        std::cout << "scheduled smem_cached_tv: " << tv->toString() << std::endl;
      }
    }

    Test failures

    • (Medium, 11) NVFuser assertion failure: unsupported swizzle of broadcast axes in multiple nvfuser & ThunderFX tests

      Test Name A100 GB200 H100 Source
      PersistentBufferTest.SmemPersistentNotSupportedIn3DReduction Link
      RNGTest.BroadcastingRNGSmem Link
      tests.python.direct.test_repro.test_domain_map_hang[nvfuser_direct_test=eager]
      tests.python.direct.test_repro.test_domain_map_hang[nvfuser_direct_test=lru_cache]
      tests.python.test_moe.test_llama4_moe_thunderfx
    • (Medium, 3) ThunderFX higher-order inplace alias update shape mismatch in test_update_aliases

      Test Name A100 GB200 H100 Source
      thunder.tests.test_update_aliases.test_higher_order_inplace_alias_update_nvfuser_cuda_thunder.dtypes.float32

    @liqiangxl
    Copy link
    Collaborator Author

    !test

    @liqiangxl liqiangxl force-pushed the llu/transpose_bank_conflict branch from 5bbb6fa to 6e53109 Compare February 8, 2026 16:37
    @liqiangxl
    Copy link
    Collaborator Author

    !test

    @liqiangxl
    Copy link
    Collaborator Author

    !test

    Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

    Labels

    None yet

    Projects

    None yet

    Development

    Successfully merging this pull request may close these issues.

    1 participant