@@ -572,27 +572,34 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b
 
         model.layers.register_module(layer_id, transformer_block)
 
-        moe_module._run_experts_grouped_mm = torch.compile(
-            moe_module._run_experts_grouped_mm,
-            backend=compile_config.backend,
-            fullgraph=True,
+        # Patch some globals only once (apply_compile is called multiple times for PP setup)
+        already_patched = (
+            "_run_experts_grouped_mm_dynamic"
+            in moe_module._run_experts_grouped_mm.__qualname__
         )
+        if not already_patched:
+            moe_module._run_experts_grouped_mm = torch.compile(
+                moe_module._run_experts_grouped_mm,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )
 
-        if ep_enabled:
-            compiled_fn = moe_module._run_experts_grouped_mm
-
-            def _run_experts_grouped_mm_dynamic(
-                w1: torch.Tensor,
-                w2: torch.Tensor,
-                w3: torch.Tensor,
-                x: torch.Tensor,
-                num_tokens_per_expert: torch.Tensor,
-            ) -> torch.Tensor:
-                # dynamic number of tokens in expert parallel
-                torch._dynamo.mark_dynamic(x, 0)
-                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
-
-            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+            if ep_enabled:
+                compiled_fn = moe_module._run_experts_grouped_mm
+
+                # keep function logic in sync with `already_patched` above
+                def _run_experts_grouped_mm_dynamic(
+                    w1: torch.Tensor,
+                    w2: torch.Tensor,
+                    w3: torch.Tensor,
+                    x: torch.Tensor,
+                    num_tokens_per_expert: torch.Tensor,
+                ) -> torch.Tensor:
+                    # dynamic number of tokens in expert parallel
+                    torch._dynamo.mark_dynamic(x, 0)
+                    return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+                moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
 
         # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
         # https://github.com/pytorch/pytorch/issues/166460
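
Why the `already_patched` guard works: `apply_compile` runs once per pipeline stage, and under EP the function stored on the module is a locally defined wrapper, so its `__qualname__` contains `_run_experts_grouped_mm_dynamic` after the first pass; later passes see that marker and skip re-wrapping. A minimal sketch of the mechanism, with hypothetical `MoE` and `patch_once` names standing in for the real module and setup code:

    def _install_dynamic_wrapper(fn):
        # stands in for the mark_dynamic wrapper in the patch
        def _run_experts_grouped_mm_dynamic(*args, **kwargs):
            return fn(*args, **kwargs)
        return _run_experts_grouped_mm_dynamic

    def patch_once(obj, attr):
        fn = getattr(obj, attr)
        # after the first pass, the wrapper's __qualname__ contains the marker name
        if "_run_experts_grouped_mm_dynamic" in fn.__qualname__:
            return
        setattr(obj, attr, _install_dynamic_wrapper(fn))

    class MoE:
        def _run_experts_grouped_mm(self, x):
            return x

    moe = MoE()
    patch_once(moe, "_run_experts_grouped_mm")
    first = moe._run_experts_grouped_mm
    patch_once(moe, "_run_experts_grouped_mm")  # second PP setup pass: no-op
    assert moe._run_experts_grouped_mm is first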
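
The `mark_dynamic` shim exists because, under expert parallel, the number of tokens routed to the local experts changes every step, so dim 0 of `x` has a different size on every call; marking it dynamic up front lets one compiled graph be reused instead of recompiling per shape. A runnable sketch of that pattern, with a toy `_f` standing in for the real grouped-mm path:

    import torch

    def _f(x: torch.Tensor) -> torch.Tensor:
        # toy stand-in for _run_experts_grouped_mm
        return x * 2

    compiled_f = torch.compile(_f, fullgraph=True)

    def _f_dynamic(x: torch.Tensor) -> torch.Tensor:
        # dim 0 (token count) varies per step; mark it dynamic so the first
        # compiled graph is reused rather than recompiled for each new size
        torch._dynamo.mark_dynamic(x, 0)
        return compiled_f(x)

    for n in (8, 16, 3):  # different per-step token counts
        _f_dynamic(torch.randn(n, 4))

Note the shim stays outside the compiled function, mirroring the patch: the input is marked dynamic before the compiled region is entered.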