From 37b7550a1efac4f8993745a426b030e0d1062f76 Mon Sep 17 00:00:00 2001 From: Theodore Cooper Date: Sat, 20 Jun 2026 19:52:08 +0800 Subject: [PATCH 1/6] fix: SimplifiedLayerNorm fusion with node-produced Pow exponent --- .../core/optimizer/layer_norm_fusion.cc | 28 ++++++++- .../graph_transform_test_layernorm.cc | 57 +++++++++++++++++++ 2 files changed, 84 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index c76c5c7e340c4..ee65e9e4fa151 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -752,11 +752,37 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // Assign provider to this new node. Provider should be same as the provider for old node. layer_norm_node.SetExecutionProviderType(reduce_mean_node.GetExecutionProviderType()); - // move input edges to add (first in list) across to the layer_norm_node. + // FinalizeNodeFusion moves every input edge of the first node by NodeArg name. Disconnect inputs + // that the replacement does not use, such as a Pow exponent produced by a mixed-precision Cast. + // Keep track of their producers so they can be removed if this fusion makes them dead. + InlinedVector unused_input_node_indices; + const auto first_node_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(nodes_to_remove.front().get()); + for (const auto& input_edge : first_node_input_edges) { + const bool is_replacement_input = + std::any_of(layer_norm_input_defs.cbegin(), layer_norm_input_defs.cend(), + [&input_edge](const NodeArg* input) { return input->Name() == input_edge.arg_name; }); + if (!is_replacement_input) { + unused_input_node_indices.push_back(input_edge.src_node); + graph.RemoveEdge(input_edge.src_node, input_edge.dst_node, + input_edge.src_arg_index, input_edge.dst_arg_index); + } + } + + // move input edges to pow (first in list) across to the layer_norm_node. // move output definitions and output edges from mul_node (last in list) to layer_norm_node. // remove all the other nodes. graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node); + // Remove unused input producers and any newly dead upstream nodes only after their final consumer is + // fused. A producer can be shared by multiple matched subgraphs, so it must remain while it still has users. + for (const NodeIndex unused_input_node_index : unused_input_node_indices) { + Node* unused_input_node = graph.GetNode(unused_input_node_index); + if (unused_input_node != nullptr && unused_input_node->GetOutputEdgesCount() == 0 && + !graph.NodeProducesGraphOutput(*unused_input_node)) { + graph_utils::RemoveNodesWithOneOutputBottomUp(graph, *unused_input_node); + } + } + #ifdef ENABLE_TRAINING_CORE // add one extra output def, so we have 2 output defs that match what gradient builder expected layer_norm_node.MutableOutputDefs().push_back( diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index d5aaf0bb2d2ee..2d52d88cb7898 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -493,6 +493,63 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } } +TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* pow_exponent_fp16 = + builder.MakeInitializer({}, {MLFloat16(2.0f)}); + auto* pow_exponent = builder.MakeIntermediate(); + builder.AddNode("Cast", {pow_exponent_fp16}, {pow_exponent}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + auto* epsilon = builder.MakeInitializer({}, {1e-5f}); + auto* scale = builder.MakeInitializer({4}, {1.0f, 1.0f, 1.0f, 1.0f}); + + auto add_simplified_layer_norm = [&](NodeArg* input) { + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {input, pow_exponent}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) + .AddAttribute("axes", std::vector{-1}); + builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out}); + builder.AddNode("Sqrt", {add_out}, {sqrt_out}); + builder.AddNode("Div", {input, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, scale}, {output}); + }; + + add_simplified_layer_norm(builder.MakeInput({{2, 4}})); + add_simplified_layer_norm(builder.MakeInput({{2, 4}})); + }; + + auto pre_graph_checker = [](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("Cast") == 1); + TEST_RETURN_IF_NOT(op_to_count.at("Pow") == 2); + return Status::OK(); + }; + + auto post_graph_checker = [](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.at("SimplifiedLayerNormalization") == 2); + TEST_RETURN_IF_NOT(op_to_count.find("Cast") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Pow") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("ReduceMean") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Add") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Sqrt") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Div") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Mul") == op_to_count.end()); + return Status::OK(); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 17, *logger_, std::make_unique(), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); +} + // It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph // To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { From d9c1d892ed70c5e4eda5f2ae5bcf1eaac1afca9e Mon Sep 17 00:00:00 2001 From: Theodore Cooper Date: Sun, 21 Jun 2026 02:40:25 +0800 Subject: [PATCH 2/6] test: avoid exceptions in SimplifiedLayerNorm fusion graph checks --- onnxruntime/core/optimizer/layer_norm_fusion.cc | 2 +- .../test/optimizer/graph_transform_test_layernorm.cc | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index ee65e9e4fa151..8da12d43d6afc 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -768,7 +768,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr } } - // move input edges to pow (first in list) across to the layer_norm_node. + // move input edges from the first node in nodes_to_remove to layer_norm_node. // move output definitions and output edges from mul_node (last in list) to layer_norm_node. // remove all the other nodes. graph_utils::FinalizeNodeFusion(graph, nodes_to_remove, layer_norm_node); diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 2d52d88cb7898..d39ffcc281f30 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -527,14 +527,18 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) auto pre_graph_checker = [](Graph& graph) { const auto op_to_count = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_to_count.at("Cast") == 1); - TEST_RETURN_IF_NOT(op_to_count.at("Pow") == 2); + const auto cast_it = op_to_count.find("Cast"); + TEST_RETURN_IF_NOT(cast_it != op_to_count.end() && cast_it->second == 1); + const auto pow_it = op_to_count.find("Pow"); + TEST_RETURN_IF_NOT(pow_it != op_to_count.end() && pow_it->second == 2); return Status::OK(); }; auto post_graph_checker = [](Graph& graph) { const auto op_to_count = CountOpsInGraph(graph); - TEST_RETURN_IF_NOT(op_to_count.at("SimplifiedLayerNormalization") == 2); + const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization"); + TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() && + simplified_layer_norm_it->second == 2); TEST_RETURN_IF_NOT(op_to_count.find("Cast") == op_to_count.end()); TEST_RETURN_IF_NOT(op_to_count.find("Pow") == op_to_count.end()); TEST_RETURN_IF_NOT(op_to_count.find("ReduceMean") == op_to_count.end()); From 40ab12064b8440ae04e12d94ea973f1ba4469a21 Mon Sep 17 00:00:00 2001 From: Theodore Cooper Date: Sun, 21 Jun 2026 13:36:54 +0800 Subject: [PATCH 3/6] fix: clean up Pow exponent producer with leading Cast --- .../core/optimizer/layer_norm_fusion.cc | 10 +- .../graph_transform_test_layernorm.cc | 123 ++++++++++-------- 2 files changed, 78 insertions(+), 55 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 8da12d43d6afc..1a44ccb3b7aa3 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -755,14 +755,20 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // FinalizeNodeFusion moves every input edge of the first node by NodeArg name. Disconnect inputs // that the replacement does not use, such as a Pow exponent produced by a mixed-precision Cast. // Keep track of their producers so they can be removed if this fusion makes them dead. - InlinedVector unused_input_node_indices; + InlinedHashSet unused_input_node_indices; + // Pow may follow a leading Cast and not be the first node finalized. Track its exponent producer + // explicitly because removing Pow will disconnect that edge without moving it to the replacement. + if (const Node* pow_exponent_input_node = graph_utils::GetInputNode(pow_node, 1)) { + unused_input_node_indices.insert(pow_exponent_input_node->Index()); + } + const auto first_node_input_edges = graph_utils::GraphEdge::GetNodeInputEdges(nodes_to_remove.front().get()); for (const auto& input_edge : first_node_input_edges) { const bool is_replacement_input = std::any_of(layer_norm_input_defs.cbegin(), layer_norm_input_defs.cend(), [&input_edge](const NodeArg* input) { return input->Name() == input_edge.arg_name; }); if (!is_replacement_input) { - unused_input_node_indices.push_back(input_edge.src_node); + unused_input_node_indices.insert(input_edge.src_node); graph.RemoveEdge(input_edge.src_node, input_edge.dst_node, input_edge.src_arg_index, input_edge.dst_arg_index); } diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index d39ffcc281f30..6777c015b558f 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -494,64 +494,81 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionTest) { } TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) { - auto build_test_case = [](ModelTestBuilder& builder) { - auto* pow_exponent_fp16 = - builder.MakeInitializer({}, {MLFloat16(2.0f)}); - auto* pow_exponent = builder.MakeIntermediate(); - builder.AddNode("Cast", {pow_exponent_fp16}, {pow_exponent}) - .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + for (const bool has_leading_cast : {false, true}) { + auto build_test_case = [has_leading_cast](ModelTestBuilder& builder) { + auto* pow_exponent_fp16 = + builder.MakeInitializer({}, {MLFloat16(2.0f)}); + auto* pow_exponent = builder.MakeIntermediate(); + builder.AddNode("Cast", {pow_exponent_fp16}, {pow_exponent}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + + auto* epsilon = builder.MakeInitializer({}, {1e-5f}); + auto* scale = builder.MakeInitializer({4}, {1.0f, 1.0f, 1.0f, 1.0f}); + + auto add_simplified_layer_norm = [&](NodeArg* input) { + NodeArg* layer_norm_input = input; + if (has_leading_cast) { + layer_norm_input = builder.MakeIntermediate(); + builder.AddNode("Cast", {input}, {layer_norm_input}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + } - auto* epsilon = builder.MakeInitializer({}, {1e-5f}); - auto* scale = builder.MakeInitializer({4}, {1.0f, 1.0f, 1.0f, 1.0f}); - - auto add_simplified_layer_norm = [&](NodeArg* input) { - auto* pow_out = builder.MakeIntermediate(); - auto* reduce_mean_out = builder.MakeIntermediate(); - auto* add_out = builder.MakeIntermediate(); - auto* sqrt_out = builder.MakeIntermediate(); - auto* div_out = builder.MakeIntermediate(); - auto* output = builder.MakeOutput(); - - builder.AddNode("Pow", {input, pow_exponent}, {pow_out}); - builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) - .AddAttribute("axes", std::vector{-1}); - builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out}); - builder.AddNode("Sqrt", {add_out}, {sqrt_out}); - builder.AddNode("Div", {input, sqrt_out}, {div_out}); - builder.AddNode("Mul", {div_out, scale}, {output}); + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {layer_norm_input, pow_exponent}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) + .AddAttribute("axes", std::vector{-1}); + builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out}); + builder.AddNode("Sqrt", {add_out}, {sqrt_out}); + builder.AddNode("Div", {layer_norm_input, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, scale}, {output}); + }; + + auto make_input = [&]() -> NodeArg* { + return has_leading_cast ? builder.MakeInput({{2, 4}}) + : builder.MakeInput({{2, 4}}); + }; + + add_simplified_layer_norm(make_input()); + add_simplified_layer_norm(make_input()); }; - add_simplified_layer_norm(builder.MakeInput({{2, 4}})); - add_simplified_layer_norm(builder.MakeInput({{2, 4}})); - }; - - auto pre_graph_checker = [](Graph& graph) { - const auto op_to_count = CountOpsInGraph(graph); - const auto cast_it = op_to_count.find("Cast"); - TEST_RETURN_IF_NOT(cast_it != op_to_count.end() && cast_it->second == 1); - const auto pow_it = op_to_count.find("Pow"); - TEST_RETURN_IF_NOT(pow_it != op_to_count.end() && pow_it->second == 2); - return Status::OK(); - }; + auto pre_graph_checker = [has_leading_cast](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + const auto cast_it = op_to_count.find("Cast"); + const int expected_cast_count = has_leading_cast ? 3 : 1; + TEST_RETURN_IF_NOT(cast_it != op_to_count.end() && cast_it->second == expected_cast_count); + const auto pow_it = op_to_count.find("Pow"); + TEST_RETURN_IF_NOT(pow_it != op_to_count.end() && pow_it->second == 2); + return Status::OK(); + }; - auto post_graph_checker = [](Graph& graph) { - const auto op_to_count = CountOpsInGraph(graph); - const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization"); - TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() && - simplified_layer_norm_it->second == 2); - TEST_RETURN_IF_NOT(op_to_count.find("Cast") == op_to_count.end()); - TEST_RETURN_IF_NOT(op_to_count.find("Pow") == op_to_count.end()); - TEST_RETURN_IF_NOT(op_to_count.find("ReduceMean") == op_to_count.end()); - TEST_RETURN_IF_NOT(op_to_count.find("Add") == op_to_count.end()); - TEST_RETURN_IF_NOT(op_to_count.find("Sqrt") == op_to_count.end()); - TEST_RETURN_IF_NOT(op_to_count.find("Div") == op_to_count.end()); - TEST_RETURN_IF_NOT(op_to_count.find("Mul") == op_to_count.end()); - return Status::OK(); - }; + auto post_graph_checker = [](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization"); + TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() && + simplified_layer_norm_it->second == 2); + TEST_RETURN_IF_NOT(op_to_count.find("Cast") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Pow") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("ReduceMean") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Add") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Sqrt") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Div") == op_to_count.end()); + TEST_RETURN_IF_NOT(op_to_count.find("Mul") == op_to_count.end()); + return Status::OK(); + }; - ASSERT_STATUS_OK(TestGraphTransformer( - build_test_case, 17, *logger_, std::make_unique(), - TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); + const InlinedHashSet compatible_eps; + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 17, *logger_, + std::make_unique(compatible_eps, has_leading_cast), + TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); + } } // It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph From fee778a47400b3cefb10a0169e6f2e10d9bcf7db Mon Sep 17 00:00:00 2001 From: Theodore Cooper Date: Tue, 23 Jun 2026 10:42:07 +0800 Subject: [PATCH 4/6] Add optimization-loop coverage for SimplifiedLayerNorm shared exponent fusion --- .../graph_transform_test_layernorm.cc | 108 ++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index 6777c015b558f..a7f481fd4a9a2 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -14,6 +14,7 @@ #include "core/graph/graph_utils.h" #include "core/graph/graph_viewer.h" #include "core/graph/model.h" +#include "core/mlas/inc/mlas.h" #include "core/optimizer/initializer.h" #include "core/optimizer/bias_skip_layer_norm_fusion.h" @@ -21,6 +22,7 @@ #include "core/optimizer/group_query_attention_fusion.h" #include "core/optimizer/layer_norm_fusion.h" #include "core/optimizer/skip_layer_norm_fusion.h" +#include "core/session/onnxruntime_session_options_config_keys.h" #include "test/capturing_sink.h" #include "test/unittest_util/framework_test_utils.h" @@ -571,6 +573,112 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) } } +TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionOptimizationLoopSharedPowExponent) { + if (MlasFp16AccelerationSupported()) { + GTEST_SKIP() << "Skipping test because FP16 acceleration support can avoid the CPU fp16 fallback path."; + } + + auto build_test_case = [](ModelTestBuilder& builder) { + auto* pow_exponent = builder.MakeInitializer({}, {MLFloat16(2.0f)}); + + auto add_simplified_layer_norm = [&](const std::vector& input_data) { + auto* input = builder.MakeInput({2, 4}, input_data); + auto* epsilon = builder.MakeInitializer({}, {MLFloat16(1e-4f)}); + auto* scale = builder.MakeInitializer( + {4}, {MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f), MLFloat16(1.0f)}); + + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {input, pow_exponent}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) + .AddAttribute("axes", std::vector{-1}); + builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out}); + builder.AddNode("Sqrt", {add_out}, {sqrt_out}); + builder.AddNode("Div", {input, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, scale}, {output}); + }; + + add_simplified_layer_norm({MLFloat16(0.5f), MLFloat16(1.0f), MLFloat16(1.5f), MLFloat16(2.0f), + MLFloat16(2.5f), MLFloat16(3.0f), MLFloat16(3.5f), MLFloat16(4.0f)}); + add_simplified_layer_norm({MLFloat16(1.0f), MLFloat16(1.25f), MLFloat16(1.5f), MLFloat16(1.75f), + MLFloat16(2.0f), MLFloat16(2.25f), MLFloat16(2.5f), MLFloat16(2.75f)}); + + // This independent branch provides FuseInitializersTransformer with single-consumer FP16 + // initializer Casts to fold at L4. With loop level 1, that L4 modification triggers the next L2 pass. + std::vector conv_input_data(25, MLFloat16(1.0f)); + auto* conv_input = builder.MakeInput({1, 1, 5, 5}, conv_input_data); + auto* conv_weight = builder.MakeInitializer( + {1, 1, 3, 3}, std::vector(9, MLFloat16(0.25f))); + auto* conv_bias = builder.MakeInitializer({1}, {MLFloat16(0.0f)}); + auto* conv_output = builder.MakeOutput(); + builder.AddNode("Conv", {conv_input, conv_weight, conv_bias}, {conv_output}); + }; + + auto op_count = [](const std::map& op_to_count, const std::string& op_type) { + const auto it = op_to_count.find(op_type); + return it == op_to_count.end() ? 0 : it->second; + }; + + auto set_graph_optimization_loop_level = [](const char* loop_level) { + return [loop_level](SessionOptions& session_options) { + ASSERT_STATUS_OK(session_options.config_options.AddConfigEntry( + kOrtSessionOptionsGraphOptimizationsLoopLevel, loop_level)); + }; + }; + + auto check_without_l2_repeat = [&](InferenceSessionWrapper& session) { + const auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_count(op_to_count, "SimplifiedLayerNormalization"), 0); + EXPECT_EQ(op_count(op_to_count, "Pow"), 2); + + int cpu_pow_count = 0; + for (const Node& node : session.GetGraph().Nodes()) { + if (node.OpType() == "Pow") { + EXPECT_EQ(node.GetExecutionProviderType(), kCpuExecutionProvider); + ++cpu_pow_count; + } + } + EXPECT_EQ(cpu_pow_count, 2); + }; + + auto check_with_l2_repeat = [&](InferenceSessionWrapper& session) { + const auto op_to_count = CountOpsInGraph(session.GetGraph()); + EXPECT_EQ(op_count(op_to_count, "SimplifiedLayerNormalization"), 2); + EXPECT_EQ(op_count(op_to_count, "Pow"), 0); + EXPECT_EQ(op_count(op_to_count, "ReduceMean"), 0); + EXPECT_EQ(op_count(op_to_count, "Add"), 0); + EXPECT_EQ(op_count(op_to_count, "Sqrt"), 0); + EXPECT_EQ(op_count(op_to_count, "Div"), 0); + EXPECT_EQ(op_count(op_to_count, "Mul"), 0); + + int cpu_simplified_layer_norm_count = 0; + int inserted_cast_count = 0; + for (const Node& node : session.GetGraph().Nodes()) { + if (node.OpType() == "SimplifiedLayerNormalization") { + EXPECT_EQ(node.GetExecutionProviderType(), kCpuExecutionProvider); + ++cpu_simplified_layer_norm_count; + } else if (node.OpType() == "Cast" && + node.Name().find("InsertedPrecisionFreeCast_") == 0) { + ++inserted_cast_count; + } + } + EXPECT_EQ(cpu_simplified_layer_norm_count, 2); + EXPECT_GE(inserted_cast_count, 2); + }; + + TransformerTester(build_test_case, check_without_l2_repeat, + TransformerLevel::Level1, TransformerLevel::MaxLevel, 17, + 1e-3, 1e-3, nullptr, set_graph_optimization_loop_level("0")); + TransformerTester(build_test_case, check_with_l2_repeat, + TransformerLevel::Level1, TransformerLevel::MaxLevel, 17, + 1e-3, 1e-3, nullptr, set_graph_optimization_loop_level("1")); +} + // It tests the scenario when scale or bias are not Graph Inputs and not initialized in Graph // To test this added a Identity node after Scale and Bias terms to ensure LayerNormFusion works properly TEST_F(GraphTransformationTests, LayerNormScaleBiasTest) { From cb08b0e0b76cf1764d5be3be906ad3a37be63fd7 Mon Sep 17 00:00:00 2001 From: Theodore Cooper Date: Tue, 23 Jun 2026 10:58:06 +0800 Subject: [PATCH 5/6] Validate SimplifiedLayerNorm Pow exponent and epsilon input --- .../core/optimizer/layer_norm_fusion.cc | 77 +++++++++++- .../graph_transform_test_layernorm.cc | 114 +++++++++++++++++- 2 files changed, 187 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 1a44ccb3b7aa3..98081b06c3f29 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -78,6 +78,74 @@ static std::vector GetAxesFromReduceMeanNode(Node& reduce_mean_node, co return axes_values; }; +static bool TryGetScalarInitializerAsDouble(const Graph& graph, const NodeArg& node_arg, double& value) { + const auto* tensor_proto = graph_utils::GetConstantInitializer(graph, node_arg.Name()); + if (tensor_proto == nullptr) { + return false; + } + + Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; + if (initializer.size() != 1) { + return false; + } + + switch (tensor_proto->data_type()) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_DOUBLE: + value = initializer.data()[0]; + return true; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + value = static_cast(initializer.data()[0]); + return true; + default: + return false; + } +} + +static bool IsPowExponentTwo(const Graph& graph, const Node& pow_node) { + const auto& pow_inputs = pow_node.InputDefs(); + if (pow_inputs.size() < 2 || pow_inputs[1] == nullptr) { + return false; + } + + double exponent_value = 0.0; + if (TryGetScalarInitializerAsDouble(graph, *pow_inputs[1], exponent_value)) { + return exponent_value == 2.0; + } + + const Node* exponent_input_node = graph_utils::GetInputNode(pow_node, 1); + if (exponent_input_node == nullptr || + !graph_utils::IsSupportedOptypeVersionAndDomain(*exponent_input_node, "Cast", {9, 13, 19, 21, 23, 24, 25}) || + exponent_input_node->InputDefs().empty() || exponent_input_node->InputDefs()[0] == nullptr) { + return false; + } + + return TryGetScalarInitializerAsDouble(graph, *exponent_input_node->InputDefs()[0], exponent_value) && + exponent_value == 2.0; +} + +static const NodeArg* GetOtherAddInput(const Node& add_node, const NodeArg& known_input) { + const auto& add_inputs = add_node.InputDefs(); + if (add_inputs.size() < 2) { + return nullptr; + } + + if (add_inputs[0] != nullptr && add_inputs[0]->Name() == known_input.Name()) { + return add_inputs[1]; + } + + if (add_inputs[1] != nullptr && add_inputs[1]->Name() == known_input.Name()) { + return add_inputs[0]; + } + + return nullptr; +} + /** Layer Normalization will fuse LayerNormalization into one node : +---------------------+ @@ -558,7 +626,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13, 15}) || !graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, pow_node, 1) || graph.NodeProducesGraphOutput(pow_node) || - !IsSupportedDataType(pow_node)) { + !IsSupportedDataType(pow_node) || !IsPowExponentTwo(graph, pow_node)) { continue; } nodes_to_remove.push_back(pow_node); @@ -590,6 +658,11 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr } nodes_to_remove.push_back(add_node); + const NodeArg* epsilon_input = GetOtherAddInput(add_node, *reduce_mean_node.MutableOutputDefs()[0]); + if (epsilon_input == nullptr) { + continue; + } + const Node* p_sqrt = graph_utils::FirstChildByType(add_node, "Sqrt"); if (p_sqrt == nullptr) { continue; @@ -728,7 +801,7 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr // Get constant "epsilon" from "Add" node if available. Else, default value will be used. const ONNX_NAMESPACE::TensorProto* tensor_proto = - graph_utils::GetConstantInitializer(graph, add_node.MutableInputDefs()[1]->Name()); + graph_utils::GetConstantInitializer(graph, epsilon_input->Name()); if (tensor_proto != nullptr && tensor_proto->data_type() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT) { Initializer initializer{graph, *tensor_proto, graph.ModelPath()}; // epsilon must be a scalar/1-element tensor; fall back to default otherwise. diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index a7f481fd4a9a2..b9d2619443cfb 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -541,6 +541,12 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) }; auto pre_graph_checker = [has_leading_cast](Graph& graph) { + if (has_leading_cast) { + for (Node& node : graph.Nodes()) { + node.SetExecutionProviderType(kCudaExecutionProvider); + } + } + const auto op_to_count = CountOpsInGraph(graph); const auto cast_it = op_to_count.find("Cast"); const int expected_cast_count = has_leading_cast ? 3 : 1; @@ -550,7 +556,7 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) return Status::OK(); }; - auto post_graph_checker = [](Graph& graph) { + auto post_graph_checker = [has_leading_cast](Graph& graph) { const auto op_to_count = CountOpsInGraph(graph); const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization"); TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() && @@ -562,17 +568,121 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionSharedCastPowExponent) TEST_RETURN_IF_NOT(op_to_count.find("Sqrt") == op_to_count.end()); TEST_RETURN_IF_NOT(op_to_count.find("Div") == op_to_count.end()); TEST_RETURN_IF_NOT(op_to_count.find("Mul") == op_to_count.end()); + + if (has_leading_cast) { + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "SimplifiedLayerNormalization") { + TEST_RETURN_IF_NOT(node.GetExecutionProviderType() == kCudaExecutionProvider); + } + } + } + return Status::OK(); }; const InlinedHashSet compatible_eps; ASSERT_STATUS_OK(TestGraphTransformer( build_test_case, 17, *logger_, - std::make_unique(compatible_eps, has_leading_cast), + std::make_unique(compatible_eps), TransformerLevel::Level2, 1, pre_graph_checker, post_graph_checker)); } } +TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionAddEpsilonInput0) { + auto build_test_case = [](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 4}}); + auto* pow_exponent = builder.MakeInitializer({}, {2.0f}); + auto* epsilon = builder.MakeInitializer({}, {1e-4f}); + auto* scale = builder.MakeInitializer({4}, {1.0f, 1.0f, 1.0f, 1.0f}); + + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {input, pow_exponent}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) + .AddAttribute("axes", std::vector{-1}); + builder.AddNode("Add", {epsilon, reduce_mean_out}, {add_out}); + builder.AddNode("Sqrt", {add_out}, {sqrt_out}); + builder.AddNode("Div", {input, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, scale}, {output}); + }; + + auto post_graph_checker = [](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization"); + TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() && + simplified_layer_norm_it->second == 1); + + for (const Node& node : graph.Nodes()) { + if (node.OpType() == "SimplifiedLayerNormalization") { + const auto& attributes = node.GetAttributes(); + const auto epsilon_it = attributes.find("epsilon"); + TEST_RETURN_IF_NOT(epsilon_it != attributes.end()); + TEST_RETURN_IF_NOT(epsilon_it->second.f() == 1e-4f); + } + } + + return Status::OK(); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 17, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); +} + +TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionRequiresPowExponentTwo) { + for (const bool has_cast_exponent : {false, true}) { + auto build_test_case = [has_cast_exponent](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 4}}); + NodeArg* pow_exponent = nullptr; + if (has_cast_exponent) { + auto* pow_exponent_fp16 = builder.MakeInitializer({}, {MLFloat16(3.0f)}); + pow_exponent = builder.MakeIntermediate(); + builder.AddNode("Cast", {pow_exponent_fp16}, {pow_exponent}) + .AddAttribute("to", static_cast(ONNX_NAMESPACE::TensorProto_DataType_FLOAT)); + } else { + pow_exponent = builder.MakeInitializer({}, {3.0f}); + } + + auto* epsilon = builder.MakeInitializer({}, {1e-5f}); + auto* scale = builder.MakeInitializer({4}, {1.0f, 1.0f, 1.0f, 1.0f}); + + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {input, pow_exponent}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) + .AddAttribute("axes", std::vector{-1}); + builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out}); + builder.AddNode("Sqrt", {add_out}, {sqrt_out}); + builder.AddNode("Div", {input, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, scale}, {output}); + }; + + auto post_graph_checker = [](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + TEST_RETURN_IF_NOT(op_to_count.find("SimplifiedLayerNormalization") == op_to_count.end()); + const auto pow_it = op_to_count.find("Pow"); + TEST_RETURN_IF_NOT(pow_it != op_to_count.end() && pow_it->second == 1); + return Status::OK(); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 17, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); + } +} + TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionOptimizationLoopSharedPowExponent) { if (MlasFp16AccelerationSupported()) { GTEST_SKIP() << "Skipping test because FP16 acceleration support can avoid the CPU fp16 fallback path."; From d5fc2230d3fb5dfe7e5dce77f0b790869ae51bac Mon Sep 17 00:00:00 2001 From: Theodore Cooper Date: Tue, 23 Jun 2026 20:16:24 +0800 Subject: [PATCH 6/6] Support integer Pow exponents in SimplifiedLayerNorm fusion --- .../core/optimizer/layer_norm_fusion.cc | 29 ++++++++++- .../graph_transform_test_layernorm.cc | 49 +++++++++++++++++++ 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/optimizer/layer_norm_fusion.cc b/onnxruntime/core/optimizer/layer_norm_fusion.cc index 98081b06c3f29..351c8552c0595 100644 --- a/onnxruntime/core/optimizer/layer_norm_fusion.cc +++ b/onnxruntime/core/optimizer/layer_norm_fusion.cc @@ -6,6 +6,7 @@ #include "core/optimizer/utils.h" #include "float.h" #include +#include #include using namespace ONNX_NAMESPACE; @@ -102,6 +103,30 @@ static bool TryGetScalarInitializerAsDouble(const Graph& graph, const NodeArg& n case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: value = static_cast(initializer.data()[0]); return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT16: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT16: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + value = static_cast(initializer.data()[0]); + return true; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + value = static_cast(initializer.data()[0]); + return true; default: return false; } @@ -623,10 +648,12 @@ Status SimplifiedLayerNormFusion::ApplyImpl(Graph& graph, bool& modified, int gr Node& pow_node = *p_pow; ORT_RETURN_IF_ERROR(Recurse(pow_node, modified, graph_level, logger)); + // Only the Pow base/output type must be supported by SimplifiedLayerNormalization. The exponent can + // be an integer scalar 2 per the Pow schema and is validated separately. if (!graph_utils::IsSupportedOptypeVersionAndDomain(pow_node, "Pow", {7, 12, 13, 15}) || !graph_utils::IsSupportedProvider(pow_node, GetCompatibleExecutionProviders()) || !optimizer_utils::CheckOutputEdges(graph, pow_node, 1) || graph.NodeProducesGraphOutput(pow_node) || - !IsSupportedDataType(pow_node) || !IsPowExponentTwo(graph, pow_node)) { + !IsSupportedDataType(pow_node, 1) || !IsPowExponentTwo(graph, pow_node)) { continue; } nodes_to_remove.push_back(pow_node); diff --git a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc index b9d2619443cfb..1e30ebc5b5392 100644 --- a/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc +++ b/onnxruntime/test/optimizer/graph_transform_test_layernorm.cc @@ -6,6 +6,8 @@ #endif #include +#include +#include #include "gtest/gtest.h" @@ -635,6 +637,53 @@ TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionAddEpsilonInput0) { TransformerLevel::Level2, 1, nullptr, post_graph_checker)); } +TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionAllowsIntegerPowExponentTwo) { + auto run_test = [this](auto exponent_value) { + using T = std::decay_t; + + auto build_test_case = [exponent_value](ModelTestBuilder& builder) { + auto* input = builder.MakeInput({{2, 4}}); + auto* pow_exponent = builder.MakeInitializer({}, {exponent_value}); + auto* epsilon = builder.MakeInitializer({}, {1e-5f}); + auto* scale = builder.MakeInitializer({4}, {1.0f, 1.0f, 1.0f, 1.0f}); + + auto* pow_out = builder.MakeIntermediate(); + auto* reduce_mean_out = builder.MakeIntermediate(); + auto* add_out = builder.MakeIntermediate(); + auto* sqrt_out = builder.MakeIntermediate(); + auto* div_out = builder.MakeIntermediate(); + auto* output = builder.MakeOutput(); + + builder.AddNode("Pow", {input, pow_exponent}, {pow_out}); + builder.AddNode("ReduceMean", {pow_out}, {reduce_mean_out}) + .AddAttribute("axes", std::vector{-1}); + builder.AddNode("Add", {reduce_mean_out, epsilon}, {add_out}); + builder.AddNode("Sqrt", {add_out}, {sqrt_out}); + builder.AddNode("Div", {input, sqrt_out}, {div_out}); + builder.AddNode("Mul", {div_out, scale}, {output}); + }; + + auto post_graph_checker = [](Graph& graph) { + const auto op_to_count = CountOpsInGraph(graph); + const auto simplified_layer_norm_it = op_to_count.find("SimplifiedLayerNormalization"); + TEST_RETURN_IF_NOT(simplified_layer_norm_it != op_to_count.end() && + simplified_layer_norm_it->second == 1); + TEST_RETURN_IF_NOT(op_to_count.find("Pow") == op_to_count.end()); + return Status::OK(); + }; + + ASSERT_STATUS_OK(TestGraphTransformer( + build_test_case, 17, *logger_, + std::make_unique(), + TransformerLevel::Level2, 1, nullptr, post_graph_checker)); + }; + + run_test(int8_t{2}); + run_test(uint8_t{2}); + run_test(int32_t{2}); + run_test(int64_t{2}); +} + TEST_F(GraphTransformationTests, SimplifiedLayerNormFusionRequiresPowExponentTwo) { for (const bool has_cast_exponent : {false, true}) { auto build_test_case = [has_cast_exponent](ModelTestBuilder& builder) {