
[No Merge][WIP] feat: primus-turbo attn add sbhd format support#650

Draft
RuibinCheung wants to merge 2 commits into main from
dev/zhangrb/add_sbhd_format_support

Conversation

@RuibinCheung
Contributor

@RuibinCheung RuibinCheung commented Apr 8, 2026

  • Primus Turbo attention add sbhd format support

This PR is experimental. Please do not merge!

Copilot AI review requested due to automatic review settings April 8, 2026 11:30
@RuibinCheung RuibinCheung changed the title [No Merge][WIP] feat: add sbhd format support [No Merge][WIP] feat: primus-turbo attn add sbhd format support Apr 8, 2026
@RuibinCheung RuibinCheung marked this pull request as draft April 8, 2026 11:31
Contributor

Copilot AI left a comment


Pull request overview

Adds experimental support for additional QKV tensor layouts (notably sbhd) in the Primus Turbo attention wrapper, while introducing special-casing for sink attention to force a specific layout.

Changes:

  • Removes the previous manual sbhd -> bshd transpose and instead forwards qkv_format into the underlying flash_attn op.
  • Introduces a use_sink_attn flag and forces sink-attention execution to use bshd, including explicit tensor permutations for Q/K/V and the output.

Comment on lines +504 to +507
# NOTE: sink attention only support bshd format
query = query.permute(1, 0, 2, 3).contiguous()
key = key.permute(1, 0, 2, 3).contiguous()
value = value.permute(1, 0, 2, 3).contiguous()

Copilot AI Apr 8, 2026


When use_sink_attn is enabled, query/key/value are always permuted as if the incoming layout were sbhd (S,B,H,D) -> bshd (B,S,H,D). If qkv_format is already bshd (or any non-sbhd value coming from packed_seq_params), this permutation will corrupt the tensor layout while qkv_format is forced to "bshd", creating a format/tensor mismatch.

Consider either (a) explicitly asserting qkv_format == "sbhd" before permuting in the sink-attention path, or (b) permuting conditionally based on qkv_format and ensuring o is permuted back consistently to preserve the expected output layout.
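A minimal sketch of option (b), using plain shape tuples in place of real tensors. The helper name `to_bshd` and the example shapes are illustrative, not from the PR:

```python
def to_bshd(shape, qkv_format):
    """Return the bshd-ordered shape for the given qkv_format.

    Hypothetical helper: permute only when the incoming layout is
    actually sbhd, and fail loudly on formats the sink-attention
    path cannot handle, instead of permuting unconditionally.
    """
    if qkv_format == "bshd":
        return shape                # already (B, S, H, D); no-op
    if qkv_format == "sbhd":
        s, b, h, d = shape          # (S, B, H, D) -> (B, S, H, D)
        return (b, s, h, d)
    raise ValueError(f"sink attention cannot handle qkv_format={qkv_format!r}")

# e.g. to_bshd((8, 2, 4, 64), "sbhd") -> (2, 8, 4, 64)
#      to_bshd((2, 8, 4, 64), "bshd") -> (2, 8, 4, 64)
```

On real tensors the sbhd branch would be `x.permute(1, 0, 2, 3).contiguous()`, with the output permuted back the same way so callers see the layout they passed in.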

Copilot uses AI. Check for mistakes.
@@ -465,9 +465,6 @@ def forward(
)

qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)

Copilot AI Apr 8, 2026


qkv_format is now passed through to self.attn(...) without any local validation. Since this value can come from packed_seq_params, an unexpected value will likely fail deeper in the kernel with a less actionable error (and could also interact badly with the explicit format conversions in the sink-attention path).

Recommend validating qkv_format against the set of supported formats in this backend and raising a clear ValueError (or keeping an assert) before using it.

Suggested change
-qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
+qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
+supported_qkv_formats = ("sbhd", "bshd", "thd")
+if qkv_format not in supported_qkv_formats:
+    raise ValueError(
+        f"Unsupported qkv_format: {qkv_format}. "
+        f"Supported formats: {supported_qkv_formats}"
+    )

else:
    window_size = (self.sink_sliding_window, 0)

# NOTE: sink attention only support bshd format

Copilot AI Apr 8, 2026


Typo/grammar in the new comment: "sink attention only support bshd format" → "sink attention only supports bshd format".

Suggested change
# NOTE: sink attention only support bshd format
# NOTE: sink attention only supports bshd format
