fix(fsdp engine): localize DTensor norm output for Qwen models in TP #1365

Open
HT-Yuan wants to merge 1 commit into areal-project:main from HT-Yuan:fix/qwen-tp-dtensor-localize

HT-Yuan (Collaborator) commented May 25, 2026

Description

Qwen models have intermediate ops (aten.alias, aten.slice) between the final norm and lm_head that break DTensor dispatch under tensor parallelism. This commit:

  • Adds is_qwen_model() helper to identify Qwen model family.
  • Registers a forward hook on the final norm to redistribute its DTensor output to Replicate and convert to a local tensor.
  • Adjusts lm_head/score input_layouts to Replicate() for Qwen models so the downstream linear layers receive plain tensors.
  • Extracts backbone variable to avoid redundant attribute access.

Without this fix, Qwen models crash with DTensor dispatch errors when running with TP > 1.

Changes

  • areal/engine/core/model.py: Add is_qwen_model() utility function.
  • areal/engine/fsdp_utils/parallel.py:
    • Add _localize_dtensor_output hook function.
    • Conditionally set head_input_layout based on model type.
    • Register hook on final norm after parallelize_module.
    • Extract backbone for clarity and add type-checking guards.
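For illustration only, a model-family check like `is_qwen_model()` could look like the sketch below. The signature and the matching rule are assumptions: the sketch takes a Hugging Face-style `model_type` string and matches a case-insensitive "qwen" prefix, whereas the helper actually added to areal/engine/core/model.py may accept different arguments or use a different rule.

```python
def is_qwen_model(model_type: str) -> bool:
    """Return True when model_type names a Qwen-family model.

    Assumption: the family is identified by a case-insensitive "qwen"
    prefix; the real helper in this PR may work differently.
    """
    return model_type.lower().startswith("qwen")


for name in ("qwen2", "Qwen3", "llama"):
    print(name, is_qwen_model(name))
```

A check of this kind would let the TP setup choose the lm_head/score input layout and register the norm hook only for the affected family, leaving other models on the existing path.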

Related Issue

Fixes #1366

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

Additional Context


gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request introduces support for localizing DTensor outputs in Qwen models during tensor parallelization. It adds a _localize_dtensor_output hook and logic to register it on the model's final normalization layer, which is necessary for models where the path between the norm and the language head contains operations incompatible with DTensors. A review comment identifies a redundant type check for the model backbone that should be removed to simplify the code.

Comment thread areal/engine/fsdp_utils/parallel.py Outdated
Comment on lines +376 to +378
        if not isinstance(model.model, nn.Module):
            raise RuntimeError("Model does not have the required submodule 'model'.")
        backbone = model.model

medium

This check is redundant because model.model is already verified to be an instance of nn.Module at the beginning of the apply_non_moe_tp function (lines 261-262). You can directly assign backbone = model.model here.

        backbone = model.model

HT-Yuan force-pushed the fix/qwen-tp-dtensor-localize branch from 16b4658 to f1f3ae8 on May 25, 2026 12:18
HT-Yuan force-pushed the fix/qwen-tp-dtensor-localize branch from f1f3ae8 to 413ace9 on May 30, 2026 05:17

HT-Yuan commented May 30, 2026

It seems this bug was introduced by the PyTorch upgrade to 2.9.0; see pytorch/pytorch#170427.

HT-Yuan force-pushed the fix/qwen-tp-dtensor-localize branch from 413ace9 to f3d0b0f on June 2, 2026 13:27

HT-Yuan commented Jun 2, 2026

It seems this bug was introduced by the PyTorch upgrade to 2.9.0; see pytorch/pytorch#170427.

@sitabulaixizawaluduo Given that, another viable fix is to upgrade the PyTorch version instead of applying this workaround. Which solution would you prefer to proceed with?

Development

Successfully merging this pull request may close these issues.

[BUG] Qwen models crash with DTensor dispatch error under TP > 1