feat(mi300x:ck-grouped-gemm): M-aware tile selection for BF16 and FP8 #271
Open
ChengYao-amd wants to merge 1 commit into main from
Conversation
Add CU utilization-aware tile downgrade for CK grouped GEMM. When `total_tiles < NUM_CU`, the factory automatically selects a 128x128 tile instead of 256x256 to improve occupancy for small-M MoE decode.

- New 128x128 tile config and extern template declarations
- M-aware dispatch logic in the instance factory (BF16/FP16 + FP8)
- BF16 DSv2-Lite B=2 M=512: +97% (130 -> 257 TFLOPS)
- BF16 Qwen3-30B B=8 M=512: +7%
- FP8: similar gains for small-M MoE decode
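The downgrade condition described above can be sketched roughly as follows. This is an illustrative sketch, not the PR's actual code: the function and constant names are assumptions, and `kNumCU = 304` is the MI300X compute-unit count.

```cpp
#include <cstdint>

// Illustrative sketch only -- names are assumptions, not the PR's symbols.
// 304 is the number of compute units on an MI300X.
constexpr int64_t kNumCU = 304;

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// True when a grid of 256x256 tiles cannot occupy every CU, in which case
// the instance factory would pick the 128x128 tile config instead.
constexpr bool use_128_tile(int64_t group_num, int64_t M, int64_t N) {
    const int64_t total_tiles =
        group_num * ceil_div(M, 256) * ceil_div(N, 256);
    return total_tiles < kNumCU;
}
```

For example, a decode-shaped problem with 2 groups and M=512 produces far fewer than 304 tiles at 256x256 granularity, so the smaller tile is chosen; a large prefill-shaped problem is not.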
Summary
- Add a 128x128x64 tile for BF16/FP16 and a 128x128x128 tile for FP8
- When `group_num * ceil(M/256) * ceil(N/256) < NUM_CU`, downgrade from 256-size to 128-size tiles

Changes
- `csrc/kernels/grouped_gemm/ck_grouped_gemm_kernel_config.h`: new `CKGroupedGemmTileCfg_128x128x64_32x32x16_2x2x1` tile config type
- `csrc/kernels/grouped_gemm/ck_grouped_gemm_kernel_instance_factory.cu`
- `csrc/kernels/grouped_gemm/ck_grouped_gemm_kernel_template.h`
- `csrc/kernels/grouped_gemm/instantiations/ck_grouped_gemm_kernel_bf16.cu`
- `csrc/kernels/grouped_gemm/instantiations/ck_grouped_gemm_kernel_fp16.cu`

Optimization Details
Resource comparison (per CU):
Smaller tiles = more tiles = better CU occupancy for small-M MoE decode.
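To make the occupancy effect concrete, here is the tile arithmetic for an assumed small-M decode shape (group_num=2, M=512, N=4096; these numbers are illustrative and may not match the benchmarked models). MI300X has 304 CUs, so a 64-tile grid leaves most of the chip idle, while quartering the tile area quadruples the tile count:

```cpp
#include <cstdint>

// Illustrative only: assumed shapes, not the PR's benchmark configs.
// group_num=2, M=512, N=4096 on MI300X (304 CUs):
//   256x256 tiles -> 2 * 2 * 16 = 64 tiles  (at most ~21% of CUs busy)
//   128x128 tiles -> 2 * 4 * 32 = 256 tiles (up to ~84% of CUs busy)
constexpr int64_t tile_count(int64_t group_num, int64_t M, int64_t N,
                             int64_t tile) {
    const auto cdiv = [](int64_t a, int64_t b) { return (a + b - 1) / b; };
    return group_num * cdiv(M, tile) * cdiv(N, tile);
}

static_assert(tile_count(2, 512, 4096, 256) == 64,  "sparse 256-tile grid");
static_assert(tile_count(2, 512, 4096, 128) == 256, "4x more 128-size tiles");
```

The per-tile arithmetic intensity drops with the smaller tile, which is why the downgrade is gated on the tile count rather than applied unconditionally.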
Performance (MI300X)
BF16 Small M (128x128 tile triggered)
BF16 Large M (no regression)
CK Optimized vs Triton (BF16)
CK remains the best backend for small-M Grouped GEMM scenarios.
Correctness
Test Plan
- `pip3 install --no-build-isolation -e . -v` (C++ rebuild required)
- `pytest tests/pytorch/ops/test_grouped_gemm.py -x`
- `pytest tests/pytorch/ops/test_grouped_gemm_fp8.py -x`
- `python3 benchmark/ops/bench_ck_m_aware.py`
- `python3 benchmark/ops/bench_ck_fp8_m_aware.py`