Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 96 additions & 9 deletions onnxruntime/core/optimizer/group_query_attention_fusion.cc
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,70 @@
return rotary_node_1 == nullptr || rotary_node_2 == nullptr || q_node == nullptr || k_node == nullptr || v_node == nullptr;
}

// Null-safe wrapper around NodeArg::Exists(): a null pointer is treated as "does not exist".
static bool NodeArgExists(const NodeArg* node_arg) {
  if (node_arg == nullptr) {
    return false;
  }
  return node_arg->Exists();
}

// Inputs and attributes read from a RotaryEmbedding node by TryGetRotaryEmbeddingArgs().
// Every field has a neutral default, so a default-constructed instance is well defined.
struct RotaryEmbeddingArgs {
  // Cosine cache input of the rotary node.
  NodeArg* cos_cache_arg{nullptr};
  // Sine cache input of the rotary node.
  NodeArg* sin_cache_arg{nullptr};
  // position_ids input of the rotary node.
  NodeArg* position_ids_arg{nullptr};
  // Value of the "interleaved" attribute (0 when the attribute is absent).
  int64_t interleaved{0};
  // Value of the "rotary_embedding_dim" attribute (0 when the attribute is absent).
  int64_t rotary_embedding_dim{0};
};

// Looks up `attr_name` on `node` and returns its integer payload (AttributeProto::i()).
// Returns `default_value` when the node carries no attribute with that name.
static int64_t GetIntAttributeOrDefault(const Node& node, const std::string& attr_name, int64_t default_value) {
  if (const auto* found = graph_utils::GetNodeAttribute(node, attr_name); found != nullptr) {
    return found->i();
  }
  return default_value;
}

// Extracts the cache / position_ids inputs and the rotary attributes of `rotary_node` into `args`.
//
// Returns true only when `rotary_node` is a RotaryEmbedding in the com.microsoft or ONNX domain,
// its "interleaved" attribute is 0 or 1, its "rotary_embedding_dim" attribute is non-negative,
// and inputs 1..3 are all present. On failure `args` may be partially written (the two attribute
// fields are stored before the inputs are inspected), so callers must ignore it.
static bool TryGetRotaryEmbeddingArgs(Node& rotary_node, RotaryEmbeddingArgs& args) {
  if (rotary_node.OpType() != "RotaryEmbedding") {
    return false;
  }

  const bool is_contrib_op = rotary_node.Domain() == kMSDomain;
  if (!is_contrib_op && rotary_node.Domain() != kOnnxDomain) {
    return false;
  }

  args.interleaved = GetIntAttributeOrDefault(rotary_node, "interleaved", 0);
  args.rotary_embedding_dim = GetIntAttributeOrDefault(rotary_node, "rotary_embedding_dim", 0);
  const bool interleaved_is_flag = args.interleaved == 0 || args.interleaved == 1;
  if (!interleaved_is_flag || args.rotary_embedding_dim < 0) {
    return false;
  }

  // Both operator flavours need the same three inputs after the data tensor, just in a different
  // order. For the ONNX operator position_ids is optional; when it is omitted, ONNX RotaryEmbedding
  // uses 3D per-batch caches, which are incompatible with GroupQueryAttention's 2D rotary cache
  // inputs, so a missing position_ids rejects the node.
  auto& input_defs = rotary_node.MutableInputDefs();
  if (input_defs.size() < 4) {
    return false;
  }
  for (size_t input_index = 1; input_index <= 3; ++input_index) {
    if (!NodeArgExists(input_defs[input_index])) {
      return false;
    }
  }

  if (is_contrib_op) {
    // com.microsoft.RotaryEmbedding inputs: input, position_ids, cos_cache, sin_cache
    args.position_ids_arg = input_defs[1];
    args.cos_cache_arg = input_defs[2];
    args.sin_cache_arg = input_defs[3];
  } else {
    // ONNX RotaryEmbedding inputs: X, cos_cache, sin_cache, optional position_ids
    args.cos_cache_arg = input_defs[1];
    args.sin_cache_arg = input_defs[2];
    args.position_ids_arg = input_defs[3];
  }
  return true;
}

static void FusePreGQANodes(Graph& graph, Node* q_node, Node* k_node, Node* v_node, Node* rotary_node_1, Node* rotary_node_2, Node* new_node, NodeArg& new_node_output_arg) {
graph_utils::MoveAllNodeInputEdges(graph, *q_node, *new_node);

Expand Down Expand Up @@ -318,6 +382,11 @@

NodeArg* cos_cache_arg = nullptr;
NodeArg* sin_cache_arg = nullptr;
NodeArg* position_ids_arg = nullptr;
int64_t rotary_interleaved = 0;
int64_t rotary_embedding_dim = 0;
bool rotary_args_set = false;
bool rotary_args_mismatch = false;
NodeArg* past_key_values_key_arg = node.MutableInputDefs()[3];
NodeArg* past_key_values_value_arg = node.MutableInputDefs()[4];
NodeArg* seqlens_k = node.MutableInputDefs()[5];
Expand All @@ -334,7 +403,8 @@
for (auto pre_gqa_node = node.InputNodesBegin(); pre_gqa_node != node.InputNodesEnd(); ++pre_gqa_node) {
Node& rotary_or_v_node = *graph.GetNode(pre_gqa_node->Index());

if (rotary_or_v_node.OpType() == "RotaryEmbedding") {
RotaryEmbeddingArgs rotary_args;
if (TryGetRotaryEmbeddingArgs(rotary_or_v_node, rotary_args)) {
if (!rotary_node_1) {
rotary_node_1 = &rotary_or_v_node;
} else {
Expand All @@ -357,19 +427,28 @@
}
}

if (cos_cache_arg == nullptr) {
cos_cache_arg = rotary_or_v_node.MutableInputDefs()[2];
}

if (sin_cache_arg == nullptr) {
sin_cache_arg = rotary_or_v_node.MutableInputDefs()[3];
if (!rotary_args_set) {
cos_cache_arg = rotary_args.cos_cache_arg;
sin_cache_arg = rotary_args.sin_cache_arg;
position_ids_arg = rotary_args.position_ids_arg;
rotary_interleaved = rotary_args.interleaved;
rotary_embedding_dim = rotary_args.rotary_embedding_dim;
rotary_args_set = true;
} else if (cos_cache_arg != rotary_args.cos_cache_arg ||
sin_cache_arg != rotary_args.sin_cache_arg ||
position_ids_arg != rotary_args.position_ids_arg ||
rotary_interleaved != rotary_args.interleaved ||
rotary_embedding_dim != rotary_args.rotary_embedding_dim) {
rotary_args_mismatch = true;
}
} else if (rotary_or_v_node.OpType() == "MatMulNBits" || rotary_or_v_node.OpType() == "MatMul") {
v_node = &rotary_or_v_node;
}
}

if (CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node)) {
if (rotary_args_mismatch ||
CheckIfAnyOfRequiredGQANodesDoesNotExist(rotary_node_1, rotary_node_2, q_node, k_node, v_node) ||
cos_cache_arg == nullptr || sin_cache_arg == nullptr) {
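// Skip this GQA node: either the two rotary embeddings disagree on their inputs/attributes,
// some of the pre-GQA nodes needed for fusion were not found, or no rotary caches were captured.
// Missing nodes can be expected if the model has extra nodes in between MatMuls and rotary embeddings.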
continue;
Expand Down Expand Up @@ -489,11 +568,13 @@
FusePreGQANodes(graph, q_node, k_node, v_node, rotary_node_1, rotary_node_2, mat_mul_or_n_bits_new_node, matmul_or_nbits_output);

node.GetMutableAttributes()["do_rotary"] = ONNX_NAMESPACE::MakeAttribute("do_rotary", static_cast<int64_t>(1));
node.GetMutableAttributes()["rotary_interleaved"] =
ONNX_NAMESPACE::MakeAttribute("rotary_interleaved", rotary_interleaved);

std::string empty_name;
auto& empty_node_arg = graph.GetOrCreateNodeArg(empty_name, nullptr);

const std::array gqa_input_defs{
std::vector<NodeArg*> gqa_input_defs{

Check warning on line 577 in onnxruntime/core/optimizer/group_query_attention_fusion.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <vector> for vector<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/optimizer/group_query_attention_fusion.cc:577: Add #include <vector> for vector<> [build/include_what_you_use] [4]
&matmul_or_nbits_output,
&empty_node_arg,
&empty_node_arg,
Expand All @@ -503,10 +584,16 @@
total_seq_len,
cos_cache_arg,
sin_cache_arg};
if (position_ids_arg != nullptr) {
gqa_input_defs.push_back(position_ids_arg);

@tianleiwu tianleiwu Jun 29, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This forwards ONNX RotaryEmbedding position_ids into GQA, but GQA does not preserve the same prompt-time semantics. ONNX RotaryEmbedding treats a provided position_ids input as a full (batch_size, sequence_length) tensor and reads every token position. In the fused GQA path, prompt handling uses base-offset semantics: the CPU path sets position_ids_format = !parameters.is_first_prompt ? 1 : 0, and the CUDA paths similarly route through GQA rotary helpers instead of the ONNX op. So for first-prompt/prefill cases with non-contiguous or per-batch custom 2D positions, this fusion can silently rotate Q/K with different positions than the original ONNX nodes.

Can we either skip ONNX-domain fusion unless the position_ids are known to be contiguous/base-offset compatible, or update GQA RoPE to consume the full 2D position_ids tensor for prompt cases before enabling this rewrite?

}

auto& gqa_input_args = node.MutableInputArgsCount();
gqa_input_args[7] = 1;
gqa_input_args[8] = 1;
if (position_ids_arg != nullptr) {
gqa_input_args[9] = 1;
}

// Switch GQA input defs from unfused into the fused form.
auto& gqa_node_input_defs = node.MutableInputDefs();
Expand Down
196 changes: 196 additions & 0 deletions onnxruntime/test/optimizer/graph_transform_test_layernorm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,6 +607,143 @@ static void TestGQAFusion(const std::basic_string<ORTCHAR_T>& file_path, int mat
ASSERT_TRUE(op_to_count["com.microsoft.GroupQueryAttention"] == 1);
}

// Builds the unfused pattern targeted by GroupQueryAttentionFusion using ONNX-domain RotaryEmbedding:
//
//   input -> MatMul(q) -> RotaryEmbedding --\
//   input -> MatMul(k) -> RotaryEmbedding ---+--> com.microsoft.GroupQueryAttention -> output
//   input -> MatMul(v) ---------------------/
//
// include_position_ids: when true both rotary nodes get a shared position_ids input and the caches
//                       are 2D {max_sequence_length, half_rotary_dim}; when false position_ids is
//                       omitted and the caches are 3D {batch, sequence, half_rotary_dim}.
// q_interleaved / k_interleaved: "interleaved" attribute of the Q and K rotary nodes.
// rotary_embedding_dim: "rotary_embedding_dim" attribute of both rotary nodes; a non-zero value
//                       switches head_size from 16 to 32.
static void BuildOnnxRotaryEmbeddingGQAFusionGraph(ModelTestBuilder& builder,
                                                   bool include_position_ids,
                                                   int64_t q_interleaved = 0,
                                                   int64_t k_interleaved = 0,
                                                   int64_t rotary_embedding_dim = 0) {
  constexpr int64_t batch_size = 1;
  constexpr int64_t sequence_length = 2;
  constexpr int64_t input_hidden_size = 8;
  constexpr int64_t num_heads = 2;
  constexpr int64_t kv_num_heads = 1;
  constexpr int64_t max_sequence_length = 8;

  const bool full_rotary = rotary_embedding_dim == 0;
  const int64_t head_size = full_rotary ? 16 : 32;
  const int64_t q_hidden_size = num_heads * head_size;
  const int64_t kv_hidden_size = kv_num_heads * head_size;
  const int64_t rotary_dim = full_rotary ? head_size : rotary_embedding_dim;
  const int64_t half_rotary_dim = rotary_dim / 2;

  // Constant-filled fp16 initializer of shape {rows, cols}.
  auto constant_weight = [&builder](int64_t rows, int64_t cols, float fill) {
    const std::vector<MLFloat16> values(static_cast<size_t>(rows * cols), MLFloat16(fill));
    return builder.MakeInitializer<MLFloat16>({rows, cols}, values);
  };
  // Shape {batch, sequence, hidden} shared by the activations in this graph.
  auto activation_shape = [&](int64_t hidden) {
    return std::vector<int64_t>{batch_size, sequence_length, hidden};
  };

  NodeArg* input = builder.MakeInput<MLFloat16>({{batch_size, sequence_length, input_hidden_size}});
  NodeArg* q_weight = constant_weight(input_hidden_size, q_hidden_size, 0.5f);
  NodeArg* k_weight = constant_weight(input_hidden_size, kv_hidden_size, 0.25f);
  NodeArg* v_weight = constant_weight(input_hidden_size, kv_hidden_size, 0.125f);

  // Q/K/V projections. All three outputs are created before any MatMul node is added.
  NodeArg* q_matmul_out = builder.MakeIntermediate<MLFloat16>(activation_shape(q_hidden_size));
  NodeArg* k_matmul_out = builder.MakeIntermediate<MLFloat16>(activation_shape(kv_hidden_size));
  NodeArg* v_matmul_out = builder.MakeIntermediate<MLFloat16>(activation_shape(kv_hidden_size));
  builder.AddNode("MatMul", {input, q_weight}, {q_matmul_out});
  builder.AddNode("MatMul", {input, k_weight}, {k_matmul_out});
  builder.AddNode("MatMul", {input, v_weight}, {v_matmul_out});

  // Rotary caches: 2D when indexed through position_ids, 3D per-batch otherwise.
  std::vector<int64_t> cache_shape;
  if (include_position_ids) {
    cache_shape = {max_sequence_length, half_rotary_dim};
  } else {
    cache_shape = {batch_size, sequence_length, half_rotary_dim};
  }
  NodeArg* cos_cache = builder.MakeInput<MLFloat16>(cache_shape);
  NodeArg* sin_cache = builder.MakeInput<MLFloat16>(cache_shape);
  NodeArg* position_ids = nullptr;
  if (include_position_ids) {
    position_ids = builder.MakeInput<int64_t>({{batch_size, sequence_length}});
  }

  NodeArg* q_rotary_out = builder.MakeIntermediate<MLFloat16>(activation_shape(q_hidden_size));
  NodeArg* k_rotary_out = builder.MakeIntermediate<MLFloat16>(activation_shape(kv_hidden_size));

  // Adds one ONNX-domain RotaryEmbedding; position_ids is appended only when it was created above.
  auto add_rotary = [&](NodeArg* data, NodeArg* rotated, int64_t heads, int64_t interleaved) {
    std::vector<NodeArg*> rotary_inputs{data, cos_cache, sin_cache};
    if (position_ids != nullptr) {
      rotary_inputs.push_back(position_ids);
    }
    Node& rotary = builder.AddNode("RotaryEmbedding", rotary_inputs, {rotated}, kOnnxDomain);
    rotary.AddAttribute("num_heads", heads);
    rotary.AddAttribute("interleaved", interleaved);
    rotary.AddAttribute("rotary_embedding_dim", rotary_embedding_dim);
  };
  add_rotary(q_matmul_out, q_rotary_out, num_heads, q_interleaved);
  add_rotary(k_matmul_out, k_rotary_out, kv_num_heads, k_interleaved);

  NodeArg* past_key =
      builder.MakeInput<MLFloat16>({{batch_size, kv_num_heads, max_sequence_length, head_size}});
  NodeArg* past_value =
      builder.MakeInput<MLFloat16>({{batch_size, kv_num_heads, max_sequence_length, head_size}});
  NodeArg* seqlens_k = builder.MakeInput<int32_t>({{batch_size}});
  NodeArg* total_sequence_length = builder.MakeInput<int32_t>({{1}});
  NodeArg* gqa_output = builder.MakeOutput<MLFloat16>(activation_shape(q_hidden_size));

  // Unfused GQA: only the seven leading inputs are wired and no rotary attributes are set.
  Node& gqa = builder.AddNode("GroupQueryAttention",
                              {q_rotary_out, k_rotary_out, v_matmul_out, past_key, past_value,
                               seqlens_k, total_sequence_length},
                              {gqa_output},
                              kMSDomain);
  gqa.AddAttribute("num_heads", num_heads);
  gqa.AddAttribute("kv_num_heads", kv_num_heads);
}

// Post-transform checker asserting that the rotary embeddings were folded into GroupQueryAttention.
// Expects no RotaryEmbedding nodes, exactly one MatMul, and exactly one GroupQueryAttention whose
// "do_rotary" attribute is 1 and whose "rotary_interleaved" attribute equals
// `expected_rotary_interleaved`.
// Returns Status::OK() when every expectation holds.
// NOTE(review): TEST_RETURN_IF_NOT presumably returns a failing Status on the first violated
// condition - confirm against the macro definition.
static Status CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(Graph& graph, int64_t expected_rotary_interleaved) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 0);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 1);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() != "GroupQueryAttention") {
continue;
}

// The fused node must carry ten inputs, with inputs 7, 8 and 9 all present
// (presumably cos_cache, sin_cache and position_ids - verify against the fusion).
TEST_RETURN_IF_NOT(node.InputDefs().size() == 10);
TEST_RETURN_IF_NOT(node.InputDefs()[7] != nullptr && node.InputDefs()[7]->Exists());
TEST_RETURN_IF_NOT(node.InputDefs()[8] != nullptr && node.InputDefs()[8]->Exists());
TEST_RETURN_IF_NOT(node.InputDefs()[9] != nullptr && node.InputDefs()[9]->Exists());

// Rotary processing must be switched on in the fused node.
const auto& attrs = node.GetAttributes();
auto do_rotary_attr = attrs.find("do_rotary");
TEST_RETURN_IF_NOT(do_rotary_attr != attrs.end());
TEST_RETURN_IF_NOT(do_rotary_attr->second.i() == 1);

// The interleaved flag must be present and match the caller's expectation.
auto rotary_interleaved_attr = attrs.find("rotary_interleaved");
TEST_RETURN_IF_NOT(rotary_interleaved_attr != attrs.end());
TEST_RETURN_IF_NOT(rotary_interleaved_attr->second.i() == expected_rotary_interleaved);
}

return Status::OK();
}

// Post-transform checker for the fused graph in the non-interleaved case: delegates to
// CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved expecting rotary_interleaved == 0.
static Status CheckOnnxRotaryEmbeddingGQAFused(Graph& graph) {
  constexpr int64_t kExpectedRotaryInterleaved = 0;
  return CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(graph, kExpectedRotaryInterleaved);
}

// Post-transform checker asserting that the fusion left the graph untouched.
// Expects both RotaryEmbedding nodes, all three MatMul nodes and the single GroupQueryAttention
// node to still be present, and the GroupQueryAttention node to keep its seven inputs with
// "do_rotary" either absent or 0.
// Returns Status::OK() when every expectation holds.
// NOTE(review): TEST_RETURN_IF_NOT presumably returns a failing Status on the first violated
// condition - confirm against the macro definition.
static Status CheckOnnxRotaryEmbeddingGQANotFused(Graph& graph) {
const auto op_to_count = CountOpsInGraph(graph);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "RotaryEmbedding") == 2);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "MatMul") == 3);
TEST_RETURN_IF_NOT(OpCount(op_to_count, "com.microsoft.GroupQueryAttention") == 1);

for (const Node& node : graph.Nodes()) {
if (node.OpType() != "GroupQueryAttention") {
continue;
}

// No rotary cache / position_ids inputs may have been appended to the node.
TEST_RETURN_IF_NOT(node.InputDefs().size() == 7);
const auto& attrs = node.GetAttributes();
auto do_rotary_attr = attrs.find("do_rotary");
// A missing attribute and an explicit 0 are both accepted as "rotary disabled".
TEST_RETURN_IF_NOT(do_rotary_attr == attrs.end() || do_rotary_attr->second.i() == 0);
}

return Status::OK();
}

static void TestSkipLayerNormFusion(const std::basic_string<ORTCHAR_T>& file_path, int add_count, int ln_count,
int skip_ln_count, int cast_count, logging::Logger* logger) {
std::shared_ptr<Model> p_model;
Expand Down Expand Up @@ -796,6 +933,65 @@ TEST_F(GraphTransformationTests, GroupQueryAttentionFusionTest) {
TestGQAFusion(MODEL_FOLDER "fusion/gqa_fusion_quantized_different_head_sizes.onnx", 1, 0, logger_.get());
}

// ONNX-domain RotaryEmbedding nodes with position_ids and default attributes: fusion is expected.
TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingTest) {
  const auto build_graph = [](ModelTestBuilder& builder) {
    BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, /*include_position_ids*/ true);
  };

  const Status transform_status =
      TestGraphTransformer(build_graph, 23, *logger_,
                           std::make_unique<GroupQueryAttentionFusion>(),
                           TransformerLevel::Level2, 3,
                           /*pre_graph_checker*/ nullptr,
                           CheckOnnxRotaryEmbeddingGQAFused);
  ASSERT_STATUS_OK(transform_status);
}

// Both rotary nodes use interleaved=1: fusion is expected and the fused GroupQueryAttention must
// carry rotary_interleaved == 1.
TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingInterleavedTest) {
  const auto build_graph = [](ModelTestBuilder& builder) {
    BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, /*include_position_ids*/ true,
                                           /*q_interleaved*/ 1, /*k_interleaved*/ 1);
  };
  const auto expect_interleaved_fusion = [](Graph& graph) {
    return CheckOnnxRotaryEmbeddingGQAFusedWithInterleaved(graph, 1);
  };

  const Status transform_status =
      TestGraphTransformer(build_graph, 23, *logger_,
                           std::make_unique<GroupQueryAttentionFusion>(),
                           TransformerLevel::Level2, 3,
                           /*pre_graph_checker*/ nullptr,
                           expect_interleaved_fusion);
  ASSERT_STATUS_OK(transform_status);
}

// Both rotary nodes use rotary_embedding_dim=16 (the builder then uses head_size 32): fusion is
// still expected.
TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingPartialRotaryTest) {
  const auto build_graph = [](ModelTestBuilder& builder) {
    BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, /*include_position_ids*/ true,
                                           /*q_interleaved*/ 0, /*k_interleaved*/ 0,
                                           /*rotary_embedding_dim*/ 16);
  };

  const Status transform_status =
      TestGraphTransformer(build_graph, 23, *logger_,
                           std::make_unique<GroupQueryAttentionFusion>(),
                           TransformerLevel::Level2, 3,
                           /*pre_graph_checker*/ nullptr,
                           CheckOnnxRotaryEmbeddingGQAFused);
  ASSERT_STATUS_OK(transform_status);
}

// Q uses interleaved=0 while K uses interleaved=1: the graph is expected to stay unfused.
TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingInterleavedMismatchTest) {
  const auto build_graph = [](ModelTestBuilder& builder) {
    BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, /*include_position_ids*/ true,
                                           /*q_interleaved*/ 0, /*k_interleaved*/ 1);
  };

  const Status transform_status =
      TestGraphTransformer(build_graph, 23, *logger_,
                           std::make_unique<GroupQueryAttentionFusion>(),
                           TransformerLevel::Level2, 3,
                           /*pre_graph_checker*/ nullptr,
                           CheckOnnxRotaryEmbeddingGQANotFused);
  ASSERT_STATUS_OK(transform_status);
}

// The rotary nodes omit position_ids (the builder then uses 3D caches): the graph is expected to
// stay unfused.
TEST_F(GraphTransformationTests, GroupQueryAttentionFusionOnnxRotaryEmbeddingNoPositionIdsTest) {
  const auto build_graph = [](ModelTestBuilder& builder) {
    BuildOnnxRotaryEmbeddingGQAFusionGraph(builder, /*include_position_ids*/ false);
  };

  const Status transform_status =
      TestGraphTransformer(build_graph, 23, *logger_,
                           std::make_unique<GroupQueryAttentionFusion>(),
                           TransformerLevel::Level2, 3,
                           /*pre_graph_checker*/ nullptr,
                           CheckOnnxRotaryEmbeddingGQANotFused);
  ASSERT_STATUS_OK(transform_status);
}

TEST_F(GraphTransformationTests, SkipLayerNormFusionWithCastTest) {
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format1_with_cast.onnx", 0, 0, 1, 3, logger_.get());
TestSkipLayerNormFusion(MODEL_FOLDER "fusion/skip_layer_norm_format2_with_cast.onnx", 0, 0, 1, 3, logger_.get());
Expand Down
Loading