
feat(mi300x:ck-grouped-gemm): M-aware tile selection for BF16 and FP8#271

Open
ChengYao-amd wants to merge 1 commit into main from dev/yaoc/ck-grouped-gemm-m-aware

Conversation

@ChengYao-amd
Contributor

Summary

  • Add CU utilization-aware tile selection for CK Grouped GEMM across BF16/FP16 (Phase 3) and FP8 RowColQuant/TensorQuant/ABQuantGrouped (Round 4)
  • New tile config 128x128x64 for BF16/FP16 and 128x128x128 for FP8
  • When group_num * ceil(M/256) * ceil(N/256) < NUM_CU, downgrade from 256-size to 128-size tiles
  • Supports both GFX942 (MI300X, 304 CUs) and GFX950 (MI355X, 256 CUs)
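The downgrade heuristic above can be sketched as follows. This is an illustrative reconstruction of the dispatch logic, not the actual AITER/CK identifiers; `select_tile` and `TileCfg` are hypothetical names.

```cpp
#include <cstdint>

// Hypothetical sketch of the CU-utilization-aware tile downgrade.
enum class TileCfg { Tile256x256, Tile128x128 };

constexpr int64_t ceil_div(int64_t a, int64_t b) { return (a + b - 1) / b; }

// num_cu: 304 on GFX942 (MI300X), 256 on GFX950 (MI355X).
TileCfg select_tile(int64_t group_num, int64_t m, int64_t n, int num_cu) {
    // Workgroup count if every group were tiled with the 256-wide config.
    const int64_t total_tiles =
        group_num * ceil_div(m, 256) * ceil_div(n, 256);
    // Too few tiles to cover every CU -> downgrade to 128x128, which
    // quadruples the tile count and improves occupancy.
    return total_tiles < num_cu ? TileCfg::Tile128x128
                                : TileCfg::Tile256x256;
}
```

For example, 8 groups at M=512, N=4096 produce only 8 × 2 × 16 = 256 tiles, below the 304 CUs of MI300X, so the 128x128 path is taken; at M=16384 the 256x256 path is kept.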

Changes

| File | Change |
| --- | --- |
| csrc/kernels/grouped_gemm/ck_grouped_gemm_kernel_config.h | Add CKGroupedGemmTileCfg_128x128x64_32x32x16_2x2x1 tile config type |
| csrc/kernels/grouped_gemm/ck_grouped_gemm_kernel_instance_factory.cu | Add M-aware dispatch for BF16/FP16 + FP8 RowColQuant/TensorQuant paths |
| csrc/kernels/grouped_gemm/ck_grouped_gemm_kernel_template.h | Add extern template declaration for 128x128 config |
| csrc/kernels/grouped_gemm/instantiations/ck_grouped_gemm_kernel_bf16.cu | Add 128x128 instantiation |
| csrc/kernels/grouped_gemm/instantiations/ck_grouped_gemm_kernel_fp16.cu | Add 128x128 instantiation |

Optimization Details

Resource comparison (per CU):

| Attribute | 256x256x64 | 128x128x64 | Change |
| --- | --- | --- | --- |
| LDS Usage | 64 KB | 32 KB | -50% |
| Computation / Block | 256x256x64 | 128x128x64 | -75% |
| Tiles / Problem | baseline | 4x baseline | +300% |

Smaller tiles yield more tiles per problem, and therefore better CU occupancy for small-M MoE decode.
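The tile-count arithmetic can be checked at compile time. The shape below (group_num=2, M=512, N=4096) is a hypothetical example chosen for illustration, not a measured configuration from the tables:

```cpp
#include <cstdint>

constexpr int64_t cdiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Hypothetical shape: group_num=2, M=512, N=4096.
constexpr int64_t t256 = 2 * cdiv(512, 256) * cdiv(4096, 256);  // 2*2*16
constexpr int64_t t128 = 2 * cdiv(512, 128) * cdiv(4096, 128);  // 2*4*32

static_assert(t256 == 64,  "only 64 workgroups for 304 CUs on MI300X");
static_assert(t128 == 256, "4x more tiles -> far better CU occupancy");
```

With 256x256 tiles, fewer than a quarter of the MI300X's 304 CUs get work; halving the tile edge quadruples the workgroup count.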

Performance (MI300X)

BF16 Small M (128x128 tile triggered)

| Configuration | Baseline (TFLOPS) | Optimized (TFLOPS) | Improvement |
| --- | --- | --- | --- |
| DSv2-Lite-GateUP B=2 M=512 | 130.4 | 257.1 | +97.2% |
| DSv2-Lite-GateUP B=2 M=1024 | 242.6 | 278.1 | +14.6% |
| Qwen3-30B-GateUP B=8 M=512 | 418.2 | 447.4 | +7.0% |
| DSv3-GateUP B=8 M=512 | 458.1 | 475.0 | +3.7% |
| MoE-1T-GateUP B=7 M=512 | 414.6 | 434.2 | +4.7% |

BF16 Large M (no regression)

| Configuration | Baseline (TFLOPS) | Optimized (TFLOPS) | Change |
| --- | --- | --- | --- |
| DSv3-GateUP B=8 M=4096 | 547.7 | 559.4 | +2.1% |
| DSv3-GateUP B=8 M=16384 | 539.2 | 557.5 | +3.4% |
| Kimi-K2-GateUP B=12 M=8192 | 518.1 | 540.3 | +4.3% |

CK Optimized vs Triton (BF16)

| Scenario | CK (TFLOPS) | Triton (TFLOPS) | CK Advantage |
| --- | --- | --- | --- |
| DSv2-Lite B=2 M=512 | 253.6 | 159.8 | +58.7% |
| DSv3 B=8 M=512 | 473.2 | 461.4 | +2.6% |
| Kimi-K2 B=12 M=512 | 391.3 | 381.0 | +2.7% |

CK remains the best backend for small-M Grouped GEMM scenarios.

Correctness

  • Coverage: BF16/FP16/FP8 dtypes, NN/NT/TN layouts, variable-K, and zero-length groups
  • 10,754 BF16 tests + 1,344 FP8 tests passed, 0 failures
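A naive reference covering two of the cases above, variable-K and zero-length groups, might look like the sketch below. This is purely illustrative (row-major float, hypothetical `GroupDesc` struct), not the AITER test harness or kernel interface:

```cpp
#include <vector>

// Hypothetical per-group descriptor; each group may have its own K,
// and zero-length groups (m == 0) must be a no-op.
struct GroupDesc {
    int m, n, k;
    const float *a, *b;  // row-major: a is m x k, b is k x n
    float *c;            // row-major: m x n
};

void grouped_gemm_ref(const std::vector<GroupDesc>& groups) {
    for (const GroupDesc& g : groups) {
        if (g.m == 0 || g.n == 0) continue;  // zero-length group: skip
        for (int i = 0; i < g.m; ++i)
            for (int j = 0; j < g.n; ++j) {
                float acc = 0.f;
                for (int p = 0; p < g.k; ++p)  // per-group K may differ
                    acc += g.a[i * g.k + p] * g.b[p * g.n + j];
                g.c[i * g.n + j] = acc;
            }
    }
}
```

A real test would compare a kernel's output against such a reference for every dtype/layout combination listed above.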

Test Plan

  • `pip3 install --no-build-isolation -e . -v` (C++ rebuild required)
  • `pytest tests/pytorch/ops/test_grouped_gemm.py -x`
  • `pytest tests/pytorch/ops/test_grouped_gemm_fp8.py -x`
  • `python3 benchmark/ops/bench_ck_m_aware.py`
  • `python3 benchmark/ops/bench_ck_fp8_m_aware.py`

Add CU utilization-aware tile downgrade for CK grouped GEMM.
When total_tiles < NUM_CU, automatically selects 128x128 tile
instead of 256x256 to improve occupancy for small-M MoE decode.

- New 128x128 tile config and extern template declarations
- M-aware dispatch logic in instance factory (BF16/FP16 + FP8)
- BF16 DSv2-Lite B=2 M=512: +97% (130 -> 257 TFLOPS)
- BF16 Qwen3-30B B=8 M=512: +7%
- FP8: similar gains for small-M MoE decode
@ChengYao-amd force-pushed the dev/yaoc/ck-grouped-gemm-m-aware branch from 46e3679 to 3215d50 on April 8, 2026 at 06:10.
@ChengYao-amd changed the title from "feat(ck-grouped-gemm): M-aware tile selection for BF16 and FP8" to "feat(mi300x:ck-grouped-gemm): M-aware tile selection for BF16 and FP8" on April 8, 2026.