Conversation

@fegin fegin commented Dec 11, 2025

Stack from ghstack (oldest at bottom):

Summary

1. Refactored CP dispatching:
   - A new apply_cp() function uses PyTorch's _ContextParallel parallelization plan to dispatch the attention call.
   - The CP dispatcher is enabled for the SDPA attention type inside apply_cp().
2. New CP data sharding approach:
   - Added a cp_shard() helper function that wraps PyTorch's _context_parallel_shard API.
   - Uses _HeadTailLoadBalancer for SDPA attention load balancing.
   - FlexAttention CP support is deferred to a future PR.
   - CP sharding now happens explicitly in post_dataloading_process(), where inputs, labels, and positions are sharded.
   - The new positions argument allows us to avoid sharding freqs_cis.

Note that this PR requires pytorch/pytorch#170200.
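
For readers skimming the diff, here is a minimal sketch of how the two pieces fit together. The private PyTorch names (_ContextParallel, _context_parallel_shard, _HeadTailLoadBalancer) come from the experimental attention module touched by pytorch/pytorch#170200; the exact signatures and arguments below are assumptions for illustration, not the code in this PR.

```python
# Sketch only: the private PyTorch API signatures below are assumed, not exact.
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import (
    _context_parallel_shard,
    _ContextParallel,
    _HeadTailLoadBalancer,
)
from torch.distributed.tensor.parallel import parallelize_module


def apply_cp(model: torch.nn.Module, cp_mesh: DeviceMesh, use_flex_attn: bool) -> None:
    # Register the CP parallelization plan on each block's inner attention so that
    # the attention call is dispatched to the CP implementation.
    attn_kind = (
        _ContextParallel.AttentionType.FLEX
        if use_flex_attn
        else _ContextParallel.AttentionType.SDPA
    )  # assumed enum spelling
    for block in model.layers.values():
        parallelize_module(
            block.attention.inner_attention,
            cp_mesh,
            _ContextParallel(seq_dim=2, attention_type=attn_kind),  # assumed constructor
        )


def cp_shard(
    cp_mesh: DeviceMesh,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    positions: torch.Tensor,
):
    # Shard inputs/labels/positions along the sequence dim with head-tail load
    # balancing for SDPA. freqs_cis stays unsharded; the sharded positions index
    # into it instead.
    seq_len = inputs.shape[1]
    load_balancer = _HeadTailLoadBalancer(
        seq_len, cp_mesh.size(), cp_mesh.device_type
    )  # assumed arguments
    inputs, labels, positions = _context_parallel_shard(
        mesh=cp_mesh,
        buffers=(inputs, labels, positions),
        seq_dims=(1, 1, 1),
        load_balancer=load_balancer,  # assumed keyword
    )
    return inputs, labels, positions
```

post_dataloading_process() then calls cp_shard() right after the dataloader, so the sharded inputs, labels, and positions are what the model and the loss see.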

Test

-> % python3 scripts/loss_compare.py . chienchin/loss_compare --baseline-options="--parallelism.context_parallel_degree=8" --test-options="--parallelism.context_parallel_degree=8" --steps=100 --assert-equal

pick 5903566a Improve the loss_compare.sh logic

[LOSS_COMPARE]
[LOSS_COMPARE] Asserting losses are equal...
[LOSS_COMPARE] Baseline log: /tmp/baseline_training.log
[LOSS_COMPARE] Test log: /tmp/test_training.log
[LOSS_COMPARE] Extracted 100 steps from baseline log
[LOSS_COMPARE] Extracted 100 steps from test log
test_losses_equal
(__main__.assert_losses_equal.<locals>.LossEqualityTest.test_losses_equal)
... ok

----------------------------------------------------------------------
Ran 1 test in 0.000s

OK
[LOSS_COMPARE] All losses are equal. Assertion passed!
[LOSS_COMPARE] ==========================================
[LOSS_COMPARE] LOSS COMPARISON ANALYSIS
[LOSS_COMPARE] ==========================================

[LOSS_COMPARE] Step-by-step loss comparison:
[LOSS_COMPARE] Step    Baseline Loss    Test Loss   Difference
[LOSS_COMPARE] ----    -------------    ---------   ----------
[LOSS_COMPARE] 1       8.1309           8.1309           0.000000
[LOSS_COMPARE] 2       7.8268           7.8268           0.000000
[LOSS_COMPARE] 3       7.2284           7.2284           0.000000
[LOSS_COMPARE] 4       6.4669           6.4669           0.000000
[LOSS_COMPARE] 5       5.4017           5.4017           0.000000
[LOSS_COMPARE] 6       4.7656           4.7656           0.000000
[LOSS_COMPARE] 7       4.3587           4.3587           0.000000
[LOSS_COMPARE] 8       4.0938           4.0938           0.000000
[LOSS_COMPARE] 9       4.4019           4.4019           0.000000
[LOSS_COMPARE] 10      3.7451           3.7451           0.000000
....
[LOSS_COMPARE] 90      2.802            2.802            0.000000
[LOSS_COMPARE] 91      2.7207           2.7207           0.000000
[LOSS_COMPARE] 92      2.7454           2.7454           0.000000
[LOSS_COMPARE] 93      2.6992           2.6992           0.000000
[LOSS_COMPARE] 94      2.743            2.743            0.000000
[LOSS_COMPARE] 95      2.7534           2.7534           0.000000
[LOSS_COMPARE] 96      2.8403           2.8403           0.000000
[LOSS_COMPARE] 97      2.783            2.783            0.000000
[LOSS_COMPARE] 98      3.0892           3.0892           0.000000
[LOSS_COMPARE] 99      2.7905           2.7905           0.000000
[LOSS_COMPARE] 100     2.733            2.733            0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Summary statistics:
[LOSS_COMPARE] Average baseline loss:  3.1414940000000002
[LOSS_COMPARE] Average test loss: 3.1414940000000002
[LOSS_COMPARE] Average difference:     0.000000
[LOSS_COMPARE]
[LOSS_COMPARE] Loss comparison complete. No results saved (no output
folder specified).

TODO

- This PR will invalidate torch.compile + CP due to pytorch/pytorch#170110. We will have to wait for Dynamo to fix the issue or refactor the nn.Module core logic to avoid checking hook_id.
@meta-cla bot added the CLA Signed label Dec 11, 2025
@fegin fegin requested a review from acisseJZhong December 11, 2025 21:09
logger.info("Applied DDP to the model")


def apply_cp(

Contributor:

1. Should we put this function in distributed/context_parallel.py?
2. Should we apply this to all models?

Contributor Author (@fegin), Dec 15, 2025:

1. Yes.
2. Do we actually verify CP for all models? I think llama3 and llama4, yes. I'm thinking of re-enabling CP model by model, using this refactor as the chance.
   nvm, I added it to the core models. I'll leave Flux for another PR. cc @wwwjn

"attention_norm": SequenceParallel(),
# NOTE: when the fourth argument (positions) is not None, its input layout
# and desired input layout should be Replicate()
# and desired input layout is still None as we don't convert freqs_cis to

Contributor:

Maybe we should change this after @wwwjn's PR?

Contributor:

PR #2149 has a composability issue when TP + PP is applied, and I'm trying to discuss how to fix it. I guess we could also land this PR if it's ready, and I can rebase.

Contributor Author (@fegin):

I don't have a strong opinion. Let's land whichever PR is ready first.

input_dict, labels
)
# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage

Contributor:

After this PR, we no longer seem to need freqs_cis as a model input. IIRC we previously modified the freqs_cis-related model code to make it a model input. Shall we revert that logic?

Contributor Author (@fegin):

That's a good point.

Contributor Author (@fegin):

We now use max_seq_len and reassign it during model_args initialization. I don't think this counts as a hack, so I only removed the legacy TODO (which I don't think we need anymore).

fegin added a commit that referenced this pull request Dec 15, 2025
Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom):
* #2145
* #2144
* __->__ #2143

1. Accept one "." (meaning the current commit) case to simplify the command line.
2. Ignore the untracked files.
@fegin fegin requested a review from tianyu-l December 16, 2025 07:32
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
device: torch.device,

Contributor:

Can we just use input_dict["inputs"].device?

def post_dataloading_process(
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,

Contributor:

I feel we should consolidate this into input_dict["labels"]

A tuple of (inputs, labels, extra_inputs, extra_kwargs) where:
- inputs: Main input tensor extracted from input_dict["input"].
- labels: Target labels (potentially modified by CP sharding).
- extra_inputs: Dict of auxiliary input tensors (all keys except

Contributor:

Maybe we can consolidate this into input_dict as well?

# extra_kwargs are.
extra_kwargs: dict[str, Any] = {}

attn_type = getattr(self.model_args, "attn_type", "sdpa")

Contributor:

It seems we introduce model_args into the validator because of this line.

This condition is checked again in get_attention_masks. Can we remove it here? We could return None instead of throwing in get_attention_masks when attn_type is sdpa.

The argument is that the validator is not supposed to know ModelArgs details.
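
For concreteness, a sketch of the suggested alternative; the get_attention_masks signature and the _build_flex_masks helper here are placeholders, not the actual torchtitan API:

```python
# Hypothetical sketch of the suggestion: get_attention_masks returns None for SDPA
# instead of raising, so the validator can forward whatever it gets without ever
# inspecting model_args.attn_type.
def get_attention_masks(self, input_batch, tokenizer, extra_inputs=None):
    if getattr(self.model_args, "attn_type", "sdpa") != "flex":
        return None  # plain SDPA: no FlexAttention block mask needed
    return self._build_flex_masks(input_batch, tokenizer, extra_inputs)  # placeholder helper
```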

return total_norm


def cp_shard(

Contributor:

Can this function also go to context_parallel.py? Maybe we can consolidate it with get_context_parallel_inputs, since both are short.

from torchtitan.tools.logging import logger


def apply_cp(

Contributor:

This name sounds too generic for what it's assuming; e.g., you needed a different one for Flux.

Maybe apply_cp_to_transformer_blocks, and pass in model.layers.values()?
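
A rough sketch of the suggested shape; the signature and the _ContextParallel arguments are assumptions mirroring the sketch in the PR description, not the final code:

```python
from collections.abc import Iterable

import torch.nn as nn
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental._attention import _ContextParallel
from torch.distributed.tensor.parallel import parallelize_module


def apply_cp_to_transformer_blocks(
    blocks: Iterable[nn.Module],  # e.g. model.layers.values()
    cp_mesh: DeviceMesh,
    use_flex_attn: bool,
) -> None:
    # Only touches the blocks that are passed in, so models without a standard
    # transformer stack (e.g. Flux) can use a different entry point.
    attn_kind = (
        _ContextParallel.AttentionType.FLEX
        if use_flex_attn
        else _ContextParallel.AttentionType.SDPA
    )  # assumed enum spelling
    for block in blocks:
        parallelize_module(
            block.attention.inner_attention,
            cp_mesh,
            _ContextParallel(seq_dim=2, attention_type=attn_kind),  # assumed constructor
        )
```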

Args:
model: The transformer model with layers containing attention modules
cp_mesh: Device mesh for context parallel dimension
use_flex_attn: Whether the model uses FlexAttention (True) or SDPA (False)

Contributor:

n00b q: We don't consider varlen + CP here? Why is that?

- Applies to transformer_block.attention.inner_attention for each layer
"""
# Apply context parallelism to every transformer block
# TODO: make seq_sim configurable once the implementation doesn't assume 2

Contributor:

Suggested change:
- # TODO: make seq_sim configurable once the implementation doesn't assume 2
+ # TODO: make seq_dim configurable once the implementation doesn't assume 2

else:
# This is currently required as DTensor dispatcher is not enabled to
# dispatch SDPA to CP implementation. We don't disable the CP
# dispatching in TorchTitan as it is not needed. But there is a

Contributor:

This comment is a bit confusing to me: the DTensor dispatcher is not enabled to dispatch SDPA to the CP implementation, so we explicitly enable it by calling _enable_context_parallel_dispatcher. But then why "we don't disable the CP dispatching in TorchTitan"?

)

if parallel_dims.cp_enabled:
use_flex_attn = attn_type == "flex"

Contributor:

So for now, attn_type == "varlen" will fall into the SDPA branch in apply_cp?
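
For reference, a tiny self-contained illustration of the mapping being asked about (assumed behavior based on this thread, not the PR code):

```python
# "flex" selects the FlexAttention CP path; everything else, including "varlen",
# currently falls back to the SDPA path inside apply_cp.
def _cp_attention_kind(attn_type: str) -> str:
    return "flex" if attn_type == "flex" else "sdpa"


assert _cp_attention_kind("varlen") == "sdpa"
assert _cp_attention_kind("flex") == "flex"
```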

Labels

ciflow/8gpu, CLA Signed