
Commit 61bc23f

Fix apply_compile called multiple times in PP initialization
stack-info: PR: #2135, branch: xmfan/stack/8
1 parent fbafd44

File tree: 1 file changed, +26 -19

torchtitan/models/llama4/infra/parallelize.py

Lines changed: 26 additions & 19 deletions
@@ -572,27 +572,34 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: b

         model.layers.register_module(layer_id, transformer_block)

-        moe_module._run_experts_grouped_mm = torch.compile(
-            moe_module._run_experts_grouped_mm,
-            backend=compile_config.backend,
-            fullgraph=True,
+        # Patch some globals only once (apply_compile is called multiple times for PP setup)
+        already_patched = (
+            "_run_experts_grouped_mm_dynamic"
+            in moe_module._run_experts_grouped_mm.__qualname__
         )
+        if not already_patched:
+            moe_module._run_experts_grouped_mm = torch.compile(
+                moe_module._run_experts_grouped_mm,
+                backend=compile_config.backend,
+                fullgraph=True,
+            )

-        if ep_enabled:
-            compiled_fn = moe_module._run_experts_grouped_mm
-
-            def _run_experts_grouped_mm_dynamic(
-                w1: torch.Tensor,
-                w2: torch.Tensor,
-                w3: torch.Tensor,
-                x: torch.Tensor,
-                num_tokens_per_expert: torch.Tensor,
-            ) -> torch.Tensor:
-                # dynamic number of tokens in expert parallel
-                torch._dynamo.mark_dynamic(x, 0)
-                return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
-
-            moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic
+            if ep_enabled:
+                compiled_fn = moe_module._run_experts_grouped_mm
+
+                # keep function logic in sync with `already_patched` above
+                def _run_experts_grouped_mm_dynamic(
+                    w1: torch.Tensor,
+                    w2: torch.Tensor,
+                    w3: torch.Tensor,
+                    x: torch.Tensor,
+                    num_tokens_per_expert: torch.Tensor,
+                ) -> torch.Tensor:
+                    # dynamic number of tokens in expert parallel
+                    torch._dynamo.mark_dynamic(x, 0)
+                    return compiled_fn(w1, w2, w3, x, num_tokens_per_expert)
+
+                moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic

         # NOTE: We don't compile for loop code path due to an issue with unbacked symints:
         # https://github.com/pytorch/pytorch/issues/166460
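
The change makes the patching idempotent rather than trying to prevent the repeated calls: before wrapping, it checks whether the method's __qualname__ already names the dynamic wrapper. Below is a minimal, self-contained sketch of that guard pattern, with hypothetical names (patch_run_experts, FakeMoE, _run_grouped_mm_dynamic) standing in for the torchtitan code; the real wrapper additionally applies torch.compile and torch._dynamo.mark_dynamic.

# Sketch of the idempotency guard from this commit: check the current method's
# __qualname__ before wrapping, so running the patch routine twice does not
# stack a second wrapper on top of the first. Names are illustrative only.

def patch_run_experts(module) -> None:
    # Same idea as `already_patched` in the diff: the wrapper's __qualname__
    # contains its function name, so a second call detects it and bails out.
    if "_run_grouped_mm_dynamic" in module.run_grouped_mm.__qualname__:
        return

    inner = module.run_grouped_mm  # stands in for the torch.compile'd function

    def _run_grouped_mm_dynamic(x):
        # the real wrapper calls torch._dynamo.mark_dynamic(x, 0) here
        return inner(x)

    module.run_grouped_mm = _run_grouped_mm_dynamic


class FakeMoE:
    def run_grouped_mm(self, x):
        return x


moe = FakeMoE()
patch_run_experts(moe)
patch_run_experts(moe)  # second call (as happens during PP setup) is a no-op

# still exactly one wrapper around the original method, and it still works
assert moe.run_grouped_mm.__qualname__.endswith("_run_grouped_mm_dynamic")
assert moe.run_grouped_mm(41) == 41

Without the guard, the second call would capture the first wrapper as its inner function and wrap it again, which is the multiple-patching behavior the commit title refers to.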
