Add rotary embedding onnx domain support #29261
Pull request overview
This PR updates the GroupQueryAttention fusion optimizer pass to recognize standard (ONNX-domain) RotaryEmbedding nodes when extracting the cos_cache/sin_cache inputs needed to fuse rotary embedding into the com.microsoft.GroupQueryAttention node.
Changes:
- Adds a helper to retrieve cos_cache/sin_cache NodeArgs for both com.microsoft.RotaryEmbedding and ONNX-domain RotaryEmbedding (different input ordering).
- Updates the fusion pattern-matching logic to use that helper and to require that rotary cache inputs were successfully identified before fusing.
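The input-ordering difference handled by the helper can be sketched as follows. This is a minimal, self-contained illustration, not the code in this PR: `RotaryNode`, `RotaryArgs`, and `TryGetRotaryArgs` are hypothetical stand-ins for the ONNX Runtime graph types, and the com.microsoft input order (input, position_ids, cos_cache, sin_cache) is assumed from the contrib-op schema rather than stated in this thread.

```cpp
#include <cstddef>
#include <string>
#include <vector>

// Hypothetical stand-in for a graph node; NOT the ONNX Runtime Node class.
struct RotaryNode {
  std::string op_type;
  std::string domain;               // "" denotes the ONNX domain in this sketch
  std::vector<std::string> inputs;  // input names; "" marks an omitted optional input
};

// Hypothetical result type holding the names of the resolved inputs.
struct RotaryArgs {
  std::string cos_cache;
  std::string sin_cache;
  std::string position_ids;
};

// Returns true and fills `out` when `node` is a RotaryEmbedding whose
// cos/sin caches and position_ids can be located for its domain.
bool TryGetRotaryArgs(const RotaryNode& node, RotaryArgs& out) {
  if (node.op_type != "RotaryEmbedding") return false;

  if (node.domain == "com.microsoft") {
    // Assumed order: (input, position_ids, cos_cache, sin_cache).
    if (node.inputs.size() < 4) return false;
    out = RotaryArgs{node.inputs[2], node.inputs[3], node.inputs[1]};
    return true;
  }

  if (node.domain.empty()) {
    // ONNX-domain order: (X, cos_cache, sin_cache, position_ids).
    // position_ids is optional in the op but required here, so bail if absent.
    if (node.inputs.size() < 4 || node.inputs[3].empty()) return false;
    out = RotaryArgs{node.inputs[1], node.inputs[2], node.inputs[3]};
    return true;
  }

  return false;  // unknown domain: do not fuse
}
```

A caller would treat a `false` return as "skip the fusion" rather than as an error, which matches the requirement that the cache inputs be identified before fusing.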
cc @tianleiwu
tianleiwu left a comment
Summary
The change extends GroupQueryAttentionFusion to also match the standard ONNX-domain RotaryEmbedding (X, cos_cache, sin_cache, position_ids) in addition to com.microsoft.RotaryEmbedding, plumbing position_ids through to GQA input 9 and setting the input-arg-count. The approach is sound:
- Requiring position_ids to be present for the ONNX-domain path (and bailing otherwise) correctly avoids the 3D per-batch cos/sin cache form that GQA's 2D rotary cache validation cannot consume.
- The position_ids_arg_mismatch guard and the added cos_cache_arg == nullptr || sin_cache_arg == nullptr checks make the fusion safely skip ambiguous/mixed cases.
- MutableInputArgsCount()[9] is in-bounds because the GQA schema declares formal inputs up to index 11, so UpdateInputArgCount() sizes the vector accordingly.
- Both the fused (with position_ids) and non-fused (omitted position_ids) cases are covered by new tests.
Main concern
Rotary interleaved / rotary_embedding_dim attributes are not validated or propagated. GQA's do_rotary path runs non-interleaved, full-width RoPE (rotary_interleaved defaults to 0 and the fusion never sets it). A standard ONNX RotaryEmbedding with interleaved=1, or a partial-rotary node with rotary_embedding_dim > 0 (and a correspondingly narrower cos/sin cache), is silently fused into a GQA that applies a different rotation — producing incorrect results with no error. Since this PR specifically targets a standard-ONNX export path where interleaved RoPE is common, the fusion should either verify interleaved == 0 and rotary_embedding_dim == 0 (full rotary) before matching, or propagate interleaved to GQA's rotary_interleaved. Inline comment below. (Note: the pre-existing com.microsoft.RotaryEmbedding path has the same latent gap.)
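To make the mismatch concrete, here is a small numerical sketch of the two pairing conventions applied to a single head vector. It is an illustration only: `ApplyRope` is a hypothetical helper, not an ONNX Runtime function, and it takes per-pair angles directly instead of reading a cos/sin cache.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Rotates one head vector x (even length d) by angles theta[i], i < d/2.
//   interleaved == true  pairs (x[2i], x[2i+1])
//   interleaved == false pairs (x[i],  x[i + d/2])
std::vector<double> ApplyRope(const std::vector<double>& x,
                              const std::vector<double>& theta,
                              bool interleaved) {
  const std::size_t half = x.size() / 2;
  std::vector<double> y(x.size());
  for (std::size_t i = 0; i < half; ++i) {
    const std::size_t a = interleaved ? 2 * i : i;
    const std::size_t b = interleaved ? 2 * i + 1 : i + half;
    const double c = std::cos(theta[i]);
    const double s = std::sin(theta[i]);
    y[a] = x[a] * c - x[b] * s;  // real part of the 2D rotation
    y[b] = x[a] * s + x[b] * c;  // imaginary part of the 2D rotation
  }
  return y;
}
```

For x = {1, 2, 3, 4} and angles {pi/2, 0}, the non-interleaved form rotates the pair (1, 3) and yields {-3, 2, 1, 4}, while the interleaved form rotates the pair (1, 2) and yields {-2, 1, 3, 4}. The same cache and input therefore give different outputs, which is why fusing an interleaved node into a non-interleaved GQA would change results without raising an error.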
Minor
- Only position_ids is checked for consistency between the two rotary nodes; cos_cache/sin_cache are taken from whichever rotary node is visited first without verifying the second uses the same caches. In practice Q/K share caches, but an explicit equality guard (mirroring the position_ids_arg_mismatch check) would make the fusion robust to malformed graphs.
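One possible shape for such a guard is sketched below. The types are hypothetical stand-ins (`Arg` for a NodeArg, `RotaryCaches` for whatever the pass collects per rotary node), and the comparison is by pointer identity on the assumption that a shared graph input is represented by a single arg object.

```cpp
// Hypothetical stand-in for a NodeArg; only its identity matters here.
struct Arg {
  const char* name;
};

// Hypothetical per-node record of the inputs resolved by the pass.
struct RotaryCaches {
  const Arg* cos_cache;
  const Arg* sin_cache;
  const Arg* position_ids;
};

// Returns true only when both rotary nodes resolved their caches and
// reference the very same cos_cache, sin_cache, and position_ids args.
bool RotaryInputsAgree(const RotaryCaches& q, const RotaryCaches& k) {
  if (q.cos_cache == nullptr || q.sin_cache == nullptr) return false;
  if (k.cos_cache == nullptr || k.sin_cache == nullptr) return false;
  return q.cos_cache == k.cos_cache &&
         q.sin_cache == k.sin_cache &&
         q.position_ids == k.position_ids;
}
```

With a check like this, a graph whose Q and K rotary nodes point at different caches is left unfused instead of silently inheriting the caches of the node that was visited first.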
    return false;
    }
    cos_cache_arg = input_defs[1];
    sin_cache_arg = input_defs[2];
TryGetRotaryEmbeddingArgs matches an ONNX RotaryEmbedding purely by op type, domain, and input presence, but never inspects the interleaved or rotary_embedding_dim attributes. GQA's do_rotary applies non-interleaved, full-width RoPE (rotary_interleaved defaults to 0 and is not set anywhere in this pass). So a node with interleaved=1, or partial rotary (rotary_embedding_dim > 0 with a narrower cos/sin cache), would be silently fused into a GQA that computes a different rotation — a silent numerical mismatch rather than a hard failure.
Consider rejecting the match when interleaved != 0 or rotary_embedding_dim != 0 (full rotary only), or propagate interleaved onto the fused GQA's rotary_interleaved attribute. The same gap exists for the com.microsoft.RotaryEmbedding branch above.
Fixed by propagating interleaved to GQA’s rotary_interleaved and rejecting Q/K RotaryEmbedding mismatches.
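A rough sketch of what such an attribute check could look like, under stated assumptions: `RotaryAttrs` and `ResolveRotaryInterleaved` are hypothetical names, not code from this PR, and the rejection of partial rotary (rotary_embedding_dim != 0) follows the reviewer's suggestion rather than anything confirmed in the fix.

```cpp
#include <cstdint>

// Hypothetical snapshot of the attributes read from one RotaryEmbedding node.
struct RotaryAttrs {
  std::int64_t interleaved;
  std::int64_t rotary_embedding_dim;
};

// Decides the rotary_interleaved value for the fused GQA node.
// Returns false (do not fuse) when either node uses partial rotary or
// when the Q and K nodes disagree on `interleaved`.
bool ResolveRotaryInterleaved(const RotaryAttrs& q, const RotaryAttrs& k,
                              std::int64_t& rotary_interleaved) {
  // Assumption: only full-width rotary is fused, per the review suggestion.
  if (q.rotary_embedding_dim != 0 || k.rotary_embedding_dim != 0) return false;
  // Q and K must agree, otherwise a single GQA attribute cannot represent both.
  if (q.interleaved != k.interleaved) return false;
  rotary_interleaved = q.interleaved;
  return true;
}
```

Returning a flag plus an output value keeps the decision in one place: the caller either sets the attribute on the fused node or skips the fusion entirely.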
Description
Mobius exports the standard ONNX-domain RotaryEmbedding op. This PR adds support for it in the GroupQueryAttention fusion.
Motivation and Context