fix(fsdp engine): localize DTensor norm output for Qwen models in TP#1365
fix(fsdp engine): localize DTensor norm output for Qwen models in TP#1365HT-Yuan wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Code Review
This pull request introduces support for localizing DTensor outputs in Qwen models during tensor parallelization. It adds a _localize_dtensor_output hook and logic to register it on the model's final normalization layer, which is necessary for models where the path between the norm and the language head contains operations incompatible with DTensors. A review comment identifies a redundant type check for the model backbone that should be removed to simplify the code.
| if not isinstance(model.model, nn.Module): | ||
| raise RuntimeError("Model does not have the required submodule 'model'.") | ||
| backbone = model.model |
16b4658 to
f1f3ae8
Compare
f1f3ae8 to
413ace9
Compare
|
It seems this bug was introduced in a PyTorch version upgrade (2.9.0). pytorch/pytorch#170427 |
Qwen models have intermediate ops (aten.alias, aten.slice) between the final norm and lm_head that break DTensor dispatch under tensor parallelism. This commit: - Adds is_qwen_model() helper to identify Qwen model family. - Registers a forward hook on the final norm to redistribute its DTensor output to Replicate and convert to a local tensor. - Adjusts lm_head/score input_layouts to Replicate() for Qwen models so the downstream linear layers receive plain tensors. - Extracts backbone variable to avoid redundant attribute access. Without this fix, Qwen models crash with DTensor dispatch errors when running with TP > 1.
413ace9 to
f3d0b0f
Compare
@sitabulaixizawaluduo |
Description
Qwen models have intermediate ops (aten.alias, aten.slice) between the final norm and lm_head that break DTensor dispatch under tensor parallelism. This commit:
Without this fix, Qwen models crash with DTensor dispatch errors when running with TP > 1.
Changes
areal/engine/core/model.py: Addis_qwen_model()utility function.areal/engine/fsdp_utils/parallel.py:_localize_dtensor_outputhook function.head_input_layoutbased on model type.parallelize_module.backbonefor clarity and add type-checking guards.Related Issue
Fixes #1366
Type of Change
Checklist
pre-commit run --all-files)./docs/build_all.sh)main/review-prcommand/create-prBreaking Change Details (if applicable):
Additional Context
Need help? Check the Contributing Guide or ask in
GitHub Discussions!