
[release/2.7] Optimize Flex-Attention occupancy for head_dim=128#2957

Merged
jataylo merged 1 commit into release/2.7 from tomjen12/release/2.7-flex-attn-warp-optimization
Feb 11, 2026

Conversation


@tomjen12 tomjen12 commented Feb 4, 2026

[ROCm] Optimize Flex-Attention occupancy for head_dim=128

Summary

This PR adjusts n_warps from 8 to 4 for head_dim=128 configurations on ROCm for gfx942.
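The heuristic the PR describes can be sketched as follows. This is an illustrative reconstruction, not the actual PyTorch/Inductor config code; the function and parameter names are hypothetical, but the decision it encodes (num_warps 8 → 4 for head_dim=128 on ROCm gfx942) is taken from the PR description.

```python
def pick_num_warps(head_dim: int, is_rocm: bool, gcn_arch: str) -> int:
    """Illustrative warp-count heuristic for Flex-Attention kernel configs.

    On ROCm gfx942, a head_dim=128 tile at num_warps=8 uses enough
    registers/LDS per workgroup to limit occupancy; dropping to 4 warps
    lets more workgroups run concurrently per CU.
    """
    if is_rocm and gcn_arch == "gfx942" and head_dim == 128:
        return 4  # optimized value from this PR
    return 8  # previous default retained for other shapes/targets


# Example: the tuned case vs. an unchanged case.
print(pick_num_warps(128, True, "gfx942"))   # 4
print(pick_num_warps(64, True, "gfx942"))    # 8
```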

Performance Impact (320 Valid Cases)

The following table shows the Geometric Mean (Geomean) speedups compared to the current n_warps=8 baseline:

| Attention Pattern | Test Count | Fwd Speedup | Bwd Speedup |
| --- | --- | --- | --- |
| alibi | 32 | 1.07x | 1.60x |
| causal | 32 | 1.08x | 1.44x |
| noop | 32 | 1.07x | 1.38x |
| prefix_lm | 112 | 1.09x | 1.40x |
| sliding_window | 112 | 1.07x | 1.27x |
| Overall (Geomean) | 320 | 1.08x | 1.41x |

Benchmark Coverage:

  • Batch Size: [1, 2, 4, 8]
  • Heads: [16, 32]
  • Sequence Length: [512, 1024, 2048, 4096]
  • Masking: window_size (sliding_window) and prefix length (prefix_lm) ∈ [128, 256, 512, 1024, 2048].
  • Note: Redundant cases (e.g., window_size > seq_len) were excluded.
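The sweep above can be reconstructed as a small case generator. This is an assumed reconstruction: the PR does not show the harness, and the exact redundancy filter is not stated, but excluding cases where the mask size is at least the sequence length reproduces the reported counts (32 cases for the unmasked patterns, 112 each for prefix_lm and sliding_window, 320 total).

```python
from itertools import product

# Parameter grid from the "Benchmark Coverage" list above.
BATCH_SIZES = [1, 2, 4, 8]
NUM_HEADS = [16, 32]
SEQ_LENS = [512, 1024, 2048, 4096]
MASK_SIZES = [128, 256, 512, 1024, 2048]


def generate_cases():
    """Enumerate the benchmark cases; returns (pattern, B, H, S, mask) tuples."""
    cases = []
    # Patterns with no mask-size parameter: 4 * 2 * 4 = 32 cases each.
    for pattern in ("alibi", "causal", "noop"):
        for b, h, s in product(BATCH_SIZES, NUM_HEADS, SEQ_LENS):
            cases.append((pattern, b, h, s, None))
    # Masked patterns: assumed filter drops mask >= seq_len as redundant,
    # leaving 112 cases each (2 + 3 + 4 + 5 mask sizes across the seq lens,
    # times 8 batch/head combinations).
    for pattern in ("prefix_lm", "sliding_window"):
        for b, h, s, m in product(BATCH_SIZES, NUM_HEADS, SEQ_LENS, MASK_SIZES):
            if m >= s:
                continue
            cases.append((pattern, b, h, s, m))
    return cases


print(len(generate_cases()))  # 320
```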

Adjust n_warps from 8 to 4 for head_dim=128 configurations to improve performance stability across different attention patterns.

- Forward speedup: ~1.07x geomean uplift.
- Backward speedup: 1.27x to 1.60x geomean uplift.
- Validated with a filtered sweep of 320 unique cases.
@tomjen12 tomjen12 requested review from jataylo and jeffdaily February 4, 2026 03:25
@tomjen12 tomjen12 changed the title [ROCm] Optimize Flex-Attention occupancy for head_dim=128 [release/2.7] Optimize Flex-Attention occupancy for head_dim=128 Feb 4, 2026

rocm-repo-management-api bot commented Feb 4, 2026

Jenkins build for 4b7cecf5fbae5ab9ef60a3b7ae25abfa63074041 commit finished as FAILURE
Links: Pipeline Overview / Build artifacts / Test Results

@jataylo jataylo merged commit 6b53931 into release/2.7 Feb 11, 2026
0 of 2 checks passed
@jataylo jataylo deleted the tomjen12/release/2.7-flex-attn-warp-optimization branch February 11, 2026 12:17