avoid warp divergence in warp specialized kernel #5830
Conversation

Review updated until commit 63ce41a

Description

Relevant files:
- Enhancement
- Tests
PR Reviewer Guide

Here are some key observations to aid the review process:

- 🧪 PR contains tests
- ⚡ Recommended focus areas for review: Error Message Clarity
Test failures

- (High, 186) CUDA driver/runtime mismatch breaking nvFuser tests on dlcluster_h100 (H100):
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/1024_3_1_0 ❌
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_1_0_0 ❌
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_1_1_1 ❌
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_2_0_1 ❌
  - ArgsortParameterizedWithBlockAndBatch.SharedMemoryRequirement/512_3_0_0 ❌
  - BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize128_ItemsPerThread4 ❌
  - BlockSizeAndItemsPerThread/ArgSortComprehensiveTest.ComprehensiveValidation/BlockSize32_ItemsPerThread5 ❌
  - ClusterReductionTest.SimpleFusionAllReduce/cluster_10_dtype_float ❌
  - ClusterReductionTest.SimpleFusionNotAllReduce/cluster_10_dtype___bfloat ❌
  - ClusterReductionTest.SimpleFusionNotAllReduce/cluster_15_dtype___bfloat ❌
  - ... with 176 more test failures omitted. Check internal logs.
- (High, 16) CUDA driver too old on dlcluster_h100 causing device_count_ensure_non_zero failures in RNGTest (H100):
  - .thunder.tests.opinfos ❌
  - .thunder.tests.test_apex_cross_entropy_executor ❌
  - .thunder.tests.test_auto_register_torchops ❌
  - .thunder.tests.test_cudnn_executor ❌
  - .thunder.tests.test_einops ❌
  - .thunder.tests.test_grad ❌
  - .thunder.tests.test_nvfuser ❌
  - .thunder.tests.test_ops ❌
  - .thunder.tests.test_sdpaex_executor ❌
  - .thunder.tests.test_torch_compile_executor ❌
  - ... with 6 more test failures omitted. Check internal logs.
- (Medium, 1) Thunder NVFuser nanoGPT autograd mismatch in test_networks (GB200):
  - thunder.tests.test_networks.test_nanogpt_complete_autograd_nvfuser_cuda_thunder.dtypes.float32 ❌
Greptile Summary

Added compile-time validation to prevent warp divergence in warp-specialized kernels when using TIDx specialization. The change enforces that both the original and padded TIDx extents are multiples of 32, the warp size.

Confidence Score: 5/5

Important Files Changed
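For context on why the multiple-of-32 constraint matters: in a warp-specialized kernel, threads branch by role based on `threadIdx.x`, and a warp (32 threads) must fall entirely on one side of that branch. The following is an illustrative CUDA sketch only, not nvFuser's generated code; `compute_threads` is a hypothetical parameter marking the boundary between roles.

```cuda
// Illustrative sketch of a warp-specialized kernel shape (not nvFuser output).
__global__ void warpSpecializedSketch(
    float* out, const float* in, int n, int compute_threads) {
  if (static_cast<int>(threadIdx.x) >= compute_threads) {
    // Producer role (e.g. async/TMA loads). Only whole warps should land here.
  } else {
    // Consumer role: the original compute threads.
    int i = blockIdx.x * compute_threads + threadIdx.x;
    if (i < n) {
      out[i] = in[i];
    }
  }
}
// If compute_threads is not a multiple of 32, one warp contains threads from
// both branches and stays diverged for the lifetime of the kernel.
```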
Sequence Diagram

```mermaid
sequenceDiagram
participant Test as Test Suite
participant PDM as ParallelDimensionMap
participant Validator as Warp Divergence Validator
Test->>PDM: adjustMappingsForWarpSpecialization(TIDx)
PDM->>PDM: Calculate other_active_threads (bdimy * bdimz)
PDM->>PDM: Calculate padding: 128 / other_active_threads
PDM->>PDM: Calculate after_pad = original_tidx + padding
alt ws_pt == TIDx
PDM->>Validator: Check original_tidx % 32 == 0
alt original_tidx not multiple of 32
Validator-->>Test: ERROR: bdimx must be multiple of 32
else original_tidx is valid
PDM->>Validator: Check after_pad % 32 == 0
alt after_pad not multiple of 32
Validator-->>Test: ERROR: padded bdimx must be multiple of 32
else after_pad is valid
PDM->>PDM: Apply padding to dimension map
PDM-->>Test: Success
end
end
else ws_pt != TIDx
PDM->>PDM: Apply padding to dimension map
PDM-->>Test: Success
end
```
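A plain C++ sketch of the checks the diagram describes; the function name and the 128-thread pad constant mirror the diagram, not nvFuser's actual API.

```cpp
#include <stdexcept>
#include <string>

// Sketch of the validation flow above; validateWarpSpecializedTidx is a
// hypothetical name, and the quantities mirror the diagram.
void validateWarpSpecializedTidx(int bdimx, int bdimy, int bdimz) {
  int other_active_threads = bdimy * bdimz;
  int padding = 128 / other_active_threads; // pad one 128-thread group, split across TIDx
  int after_pad = bdimx + padding;
  if (bdimx % 32 != 0) {
    throw std::runtime_error(
        "bdimx must be a multiple of 32, got " + std::to_string(bdimx));
  }
  if (after_pad % 32 != 0) {
    throw std::runtime_error(
        "padded bdimx must be a multiple of 32, got " + std::to_string(after_pad));
  }
}
```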
!test
```cpp
if (ws_pt == ParallelType::TIDx &&
    getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
```
style: the condition checks the padding amount, but the validation (error message) refers to the total (original + padding). This works for the current test cases, where original_tidx is always a multiple of 32, but would fail if a test added a case like dim3(96, 8, 1), where original = 96 (divisible by 32), pad = 16 (not divisible), and after_pad = 112 (not divisible by 32); the arithmetic is worked through in the sketch after the suggestion.
```diff
-if (ws_pt == ParallelType::TIDx &&
-    getTmaPadThreads(ws_pt, bdim) % 32 != 0) {
+if (ws_pt == ParallelType::TIDx &&
+    (bdim.x + getTmaPadThreads(ws_pt, bdim)) % 32 != 0) {
```
Is the test suite intended to cover only cases where the original bdimx is a multiple of 32?
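To make the review comment concrete, here is the arithmetic for the dim3(96, 8, 1) example, assuming the pad is 128 / (bdimy * bdimz) as in the sequence diagram. Both predicates are evaluated so the existing and suggested forms of the condition can be compared directly.

```cpp
#include <cstdio>

int main() {
  int bdimx = 96, bdimy = 8, bdimz = 1;
  int pad = 128 / (bdimy * bdimz); // 16
  int after_pad = bdimx + pad;     // 112
  std::printf("pad %% 32 = %d\n", pad % 32);             // 16 -> existing check fires
  std::printf("after_pad %% 32 = %d\n", after_pad % 32); // 16 -> suggested check also fires
  return 0;
}
```

Note that whenever bdimx is a multiple of 32, (bdimx + pad) % 32 equals pad % 32, so the two forms agree; they can only differ for bdimx values already rejected by the separate original-extent check.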
!test
No description provided.