[No Merge][WIP] feat: primus-turbo attn add sbhd format support #650
RuibinCheung wants to merge 2 commits into main
Conversation
Pull request overview
Adds experimental support for additional QKV tensor layouts (notably sbhd) in the Primus Turbo attention wrapper, while introducing special-casing for sink attention to force a specific layout.
Changes:
- Removes the previous manual sbhd -> bshd transpose and instead forwards qkv_format into the underlying flash_attn op.
- Introduces a use_sink_attn flag and forces sink-attention execution to use bshd, including explicit tensor permutations for Q/K/V and the output.
```python
# NOTE: sink attention only support bshd format
query = query.permute(1, 0, 2, 3).contiguous()
key = key.permute(1, 0, 2, 3).contiguous()
value = value.permute(1, 0, 2, 3).contiguous()
```
When use_sink_attn is enabled, query/key/value are always permuted as if the incoming layout were sbhd (S,B,H,D) -> bshd (B,S,H,D). If qkv_format is already bshd (or any non-sbhd value coming from packed_seq_params), this permutation will corrupt the tensor layout while qkv_format is forced to "bshd", creating a format/tensor mismatch.
Consider either (a) explicitly asserting qkv_format == "sbhd" before permuting in the sink-attention path, or (b) permuting conditionally based on qkv_format and ensuring o is permuted back consistently to preserve the expected output layout.
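A minimal sketch of option (b): a hypothetical helper (the name `to_bshd` and the error message are not from the PR) that only permutes when the incoming layout is actually sbhd, and rejects layouts the sink-attention path cannot represent. It assumes the wrapper's 4-D layout convention, sbhd = (S, B, H, D) and bshd = (B, S, H, D):

```python
def to_bshd(t, qkv_format):
    """Convert a 4-D tensor to bshd only when it arrives as sbhd.

    Hypothetical helper for the sink-attention path; layout names follow
    the wrapper's convention: sbhd = (S, B, H, D), bshd = (B, S, H, D).
    """
    if qkv_format == "sbhd":
        # Swap the sequence and batch dims: (S, B, H, D) -> (B, S, H, D).
        return t.permute(1, 0, 2, 3).contiguous()
    if qkv_format == "bshd":
        return t  # already in the layout sink attention expects
    # Anything else (e.g. thd) has no well-defined permutation here.
    raise ValueError(f"sink attention does not support qkv_format={qkv_format!r}")
```

The caller would apply this to query/key/value before the kernel, and apply the inverse permutation to `o` only when the input was sbhd, so the output layout matches what the rest of the model expects.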
```
@@ -465,9 +465,6 @@ def forward(
)

qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
```
qkv_format is now passed through to self.attn(...) without any local validation. Since this value can come from packed_seq_params, an unexpected value will likely fail deeper in the kernel with a less actionable error (and could also interact badly with the explicit format conversions in the sink-attention path).
Recommend validating qkv_format against the set of supported formats in this backend and raising a clear ValueError (or keeping an assert) before using it.
```python
qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
supported_qkv_formats = ("sbhd", "bshd", "thd")
if qkv_format not in supported_qkv_formats:
    raise ValueError(
        f"Unsupported qkv_format: {qkv_format}. "
        f"Supported formats: {supported_qkv_formats}"
    )
```
```python
else:
    window_size = (self.sink_sliding_window, 0)

# NOTE: sink attention only support bshd format
```
Typo/grammar in the new comment: "sink attention only support bshd format" → "sink attention only supports bshd format".
```python
# NOTE: sink attention only supports bshd format
```
This PR is experimental. Please do not merge!