Optimize OSFT factorized linear and gradient projection kernels#75
Optimize OSFT factorized linear and gradient projection kernels#75RobotSail wants to merge 5 commits into
Conversation
Two key optimizations that yield ~15% end-to-end training throughput
improvement on Llama-3.1-8B-Instruct (4x H100, bf16, rank_ratio=0.5):
1. Factorized linear (_factorized_linear):
- Flatten input to 2D and use torch.mm instead of batched @ operator
- Replace separate `result_high + result_low` addition with
`result.addmm_(tmp_low, U_low.T)` which fuses the low-rank
matmul and addition into a single cuBLAS call
- Eliminates one kernel launch per OSFT target per forward pass
2. Gradient projection (project_gradient_to_orthogonal_space):
- Replace Gram matrix form `G = V_high^T @ V_high; dV -= dV @ G`
with factored form `dV -= (dV @ V_high^T) @ V_high`
- Avoids materializing the (K, K) Gram matrix (e.g. 4096x4096 for
Llama), replacing it with a small (rank_low, rank_high) intermediate
- Fuse subtraction into matmul via `addmm_(alpha=-1.0)`
- The all-reduce now operates on the smaller (rank_low, rank_high)
tensor instead of (K, K), reducing NCCL communication volume
Also includes transformers v5 compatibility fixes:
- Pass both `torch_dtype` and `dtype` kwargs to from_pretrained
- Handle renamed config attribute (torch_dtype -> dtype)
- Fix dtype validation for FSDP2 mixed precision (params stored
in fp32, cast to bf16 for compute)
- Fix optimizer state validation (always fp32 for stability)
Benchmark (Llama-3.1-8B-Instruct, 4x H100 80GB, bf16, OSFT r=0.5):
Baseline: 12,232 tok/s mean, 12,766 tok/s median
Optimized: 14,080 tok/s mean, 14,385 tok/s median
Speedup: +15.1% mean, +12.7% median
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
📝 WalkthroughWalkthroughReplaces Gram-matrix V projection with a factored, batched all-reduce projection; consolidates per-module V projections into one flattened all-reduce and local apply; rewrites factorized linear forward to 2D addmm_ fusion; adds transformers v5-compatible dtype handling and relaxes training dtype validation to accept fp32 optimizer/gradients. Changes
Sequence Diagram(s)sequenceDiagram
participant Module as Module (each rank)
participant PG as AllReduce / ProcessGroup
participant Local as Local Apply
Module->>Module: compute local dV and local V_high
Module->>Module: compute coeff dV_Vt = dV @ V_high^T
Module->>PG: send flattened coeffs (batched for all modules)
PG->>PG: all-reduce (SUM) across ranks
PG->>Module: reduced coeff slice for this module
Module->>Local: local_dV.addmm_(reduced_coeff, local_V_high, alpha=-1.0)
Local->>Module: updated projected gradients
Estimated code review effort🎯 4 (Complex) | ⏱️ ~45 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
src/mini_trainer/train.py (1)
56-64: Relaxed validation logic may mask unexpected dtype issues.The new logic only raises if dtype is not
expected_param_dtypeAND nottorch.float32. This means:
- If
expected_param_dtype=torch.bfloat16, bothbf16andfp32pass silently- If a param unexpectedly becomes
fp16when expectingbf16, it still raises (correct)However, the nested condition structure is a bit confusing. Consider a clearer formulation:
♻️ Suggested clarification
if param.requires_grad and param.dtype != expected_param_dtype: - if param.dtype != torch.float32: - raise ValueError(f"Parameter {name} is not in {expected_param_dtype}, got {param.dtype}") + # FSDP2 MixedPrecisionPolicy may store params in fp32; allow this as valid + allowed_dtypes = {expected_param_dtype, torch.float32} + if param.dtype not in allowed_dtypes: + raise ValueError(f"Parameter {name} has unexpected dtype {param.dtype}, expected one of {allowed_dtypes}") if param.grad is not None and param.grad.dtype != expected_param_dtype: - if param.grad.dtype != torch.float32: - raise ValueError(f"Gradient {name} is not in {expected_param_dtype}, got {param.grad.dtype}") + allowed_dtypes = {expected_param_dtype, torch.float32} + if param.grad.dtype not in allowed_dtypes: + raise ValueError(f"Gradient {name} has unexpected dtype {param.grad.dtype}, expected one of {allowed_dtypes}")🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/train.py` around lines 56 - 64, The current nested checks around param.requires_grad, param.dtype, expected_param_dtype and torch.float32 are confusing and can silently allow unintended dtypes; update the validation in train.py so each param (and param.grad) is explicitly allowed only if its dtype equals expected_param_dtype OR equals torch.float32 (to accommodate FSDP storage), otherwise raise a ValueError referencing the parameter name; specifically replace the two nested if-blocks that check param.dtype and param.grad.dtype with clear single-condition checks using (param.dtype != expected_param_dtype and param.dtype != torch.float32) and similarly for (param.grad.dtype != expected_param_dtype and param.grad.dtype != torch.float32) for the same named parameter checks.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@src/mini_trainer/train.py`:
- Around line 56-64: The current nested checks around param.requires_grad,
param.dtype, expected_param_dtype and torch.float32 are confusing and can
silently allow unintended dtypes; update the validation in train.py so each
param (and param.grad) is explicitly allowed only if its dtype equals
expected_param_dtype OR equals torch.float32 (to accommodate FSDP storage),
otherwise raise a ValueError referencing the parameter name; specifically
replace the two nested if-blocks that check param.dtype and param.grad.dtype
with clear single-condition checks using (param.dtype != expected_param_dtype
and param.dtype != torch.float32) and similarly for (param.grad.dtype !=
expected_param_dtype and param.grad.dtype != torch.float32) for the same named
parameter checks.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 924be8a6-aaf0-4584-89c4-d5352a75724c
📒 Files selected for processing (3)
src/mini_trainer/osft_utils.pysrc/mini_trainer/setup_model_for_training.pysrc/mini_trainer/train.py
Same pattern as the existing U projection batching: collect all (dV @ V_high^T) coefficients across 224 OSFT targets, concatenate into a single flat tensor, perform one all-reduce, then split back and apply corrections. This reduces V projection from 224 all-reduce launches to 1, cutting NCCL collective launch overhead. Benchmark improvement: +2.6% on top of factored V projection. Total vs baseline: +17.9% mean throughput. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
There was a problem hiding this comment.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/mini_trainer/osft_utils.py (1)
539-544:⚠️ Potential issue | 🟡 MinorDocstring contradicts the new factored implementation.
The docstring at lines 540-544 states that the V projection "must use the Gram-matrix form" because the factored form "produce[s] column-blocks rather than partial sums — requiring an all-gather instead of an all-reduce." However, the implementation below (lines 591-598) now uses the factored form with
all_reduce.The factored form is correct here: when FSDP2 shards V_high on dim-0,
local_dV @ local_V_high.Tproduces partial sums that all-reduce correctly aggregates. Please update the docstring to reflect the new approach.📝 Suggested docstring update
- V projection must use the Gram-matrix form - dV -= dV @ (V_high^T @ V_high) because FSDP2 shards V_high on dim-0 - (the singular-vector dimension), making the factored form - dV -= (dV @ V_high^T) @ V_high produce column-blocks rather than - partial sums — requiring an all-gather instead of an all-reduce. + V projection uses the factored form dV -= (dV @ V_high^T) @ V_high + with a small (rank_low, rank_high) intermediate. When FSDP2 shards + V_high on dim-0, `dV @ V_high^T` produces partial sums that are + correctly aggregated via all-reduce, then multiplied by local V_high + rows to complete the projection.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@src/mini_trainer/osft_utils.py` around lines 539 - 544, Update the docstring for the V projection to match the implemented factored form: replace the claim that the Gram-matrix form is required and that the factored form would produce column-blocks requiring an all-gather; instead state that because FSDP2 shards V_high on dim-0, the local computation (local_dV @ local_V_high.T) produces partial sums which are correctly aggregated via all_reduce, so the code uses the factored form with all_reduce (referencing variables dV, V_high, local_dV, local_V_high and the V projection implementation).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 2148-2172: The non-distributed else branch under the "if
v_flat_parts:" check is dead code because non-distributed paths return earlier;
remove the unreachable block (lines that iterate over v_work using
v_flat_parts.pop and dV.copy_) and either replace it with an assertion like
"assert is_distributed" or rely on the surrounding logic to guarantee
is_distributed is True; update any tests or callers if they depended on that
branch. Ensure you only modify the branch inside the v_flat_parts handling
(references: v_flat_parts, is_distributed, v_work, local_V_high, local_dV, dV,
coeff_shape) to eliminate the redundant code path.
---
Outside diff comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 539-544: Update the docstring for the V projection to match the
implemented factored form: replace the claim that the Gram-matrix form is
required and that the factored form would produce column-blocks requiring an
all-gather; instead state that because FSDP2 shards V_high on dim-0, the local
computation (local_dV @ local_V_high.T) produces partial sums which are
correctly aggregated via all_reduce, so the code uses the factored form with
all_reduce (referencing variables dV, V_high, local_dV, local_V_high and the V
projection implementation).
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 175023e5-92a8-4b20-9a19-4258601a5b86
📒 Files selected for processing (1)
src/mini_trainer/osft_utils.py
|
@claude fix these CI issues |
|
Claude finished @RobotSail's task in 5m 53s —— View job Fixing CI Issues
Note on remaining test failuresThe 4 |
- Fix ruff format: add spaces around slice in batched_v[offset : offset + numel] - Update stale docstring: V projection now correctly documents the factored form with all-reduce instead of the old Gram-matrix claim - Remove dead code: non-distributed else branch in v_flat_parts handling is unreachable since non-distributed paths return early at line 2074 - Clarify dtype validation: use allowed_dtypes set instead of nested conditionals for FSDP2 mixed precision param/grad checking Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com> Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
|
@claude-code it looks like there are still some merge conflicts with the model setup script. Maybe you need to rebase this |
|
Claude finished @RobotSail's task in 2m 52s —— View job Resolving Merge Conflicts
Conflict Resolution DetailsAll three conflicts in
Merged main into the branch via | Branch |
Resolves merge conflicts in setup_model_for_training.py by combining: - trust_remote_code support from main (#78, #81) - transformers v5 dtype compatibility from this branch All three conflicts resolved by keeping both sets of changes: 1. get_model_save_dtype: trust_remote_code param + dtype fallback 2. base_model_args: dtype key + trust_remote_code key 3. save_dtype assignment: trust_remote_code passthrough + v5 dtype handling Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com>
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
|
@claude-code this has merge conflicts with the main branch, can you please figure out why and fix it if possible? Push your fixes directly back into this branch. Thanks! |
|
Claude finished @RobotSail's task in 10m 13s —— View job Resolving Merge Conflicts
Conflict DetailsThe conflicts arose because PR #74 ("Replace Gram matrix V projection with factored form") landed on main while this branch uses a different optimization strategy. Both avoid the
Resolution: kept this PR's all-reduce approach (benchmarked at +17.9% throughput). Specifically:
| | Branch |
Conflicts arose from PR #74 (factored V projection with all-gather+caching) landing on main while this branch uses a different approach (all-reduce of small coefficients + batched V all-reduces). Resolved by keeping this branch's approach which benchmarks at +17.9% throughput: - V projection: all-reduce of (rank_low, rank_high) coefficients instead of all-gathering V_high (avoids storing full V_high per rank) - Batched V projection: single NCCL all-reduce for all 224 targets instead of per-module all-gathers - Removed unused math import and OSFT_CACHE_V references (no longer needed without all-gather approach) - Updated tests: kept factored-vs-gram correctness tests, removed cache/all-gather specific tests from PR #74 Co-authored-by: Oleg Silkin <RobotSail@users.noreply.github.com>
There was a problem hiding this comment.
Actionable comments posted: 1
🧹 Nitpick comments (1)
tests/test_osft.py (1)
1594-1595: Mark this new suite as a unit test.These are CPU-only correctness tests, but the class is unmarked, so marker-based runs can't include/exclude it consistently.
🏷️ Suggested tweak
+@pytest.mark.unit class TestVProjectionFactored: """Test factored V projection correctness (dV -= (dV @ V_high^T) @ V_high)."""As per coding guidelines,
{tests,regression_tests}/**/*.py: Use pytest for testing with markers: slow, integration, unit, gpu, multi_gpu.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/test_osft.py` around lines 1594 - 1595, Add the pytest unit marker to the new test suite: ensure pytest is imported in the file (import pytest) if not already, and annotate the TestVProjectionFactored class with `@pytest.mark.unit` above its definition so the suite is recognized as a CPU-only unit test (reference the TestVProjectionFactored class name and the pytest.mark.unit decorator).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@src/mini_trainer/osft_utils.py`:
- Around line 607-614: The V-projection is wrong for FSDP2 Shard(0) layouts
because computing dV_Vt = local_dV @ local_V_high.T and then all_reduce mixes
coefficient blocks from different ranks; fix by restoring a layout-aware gather:
either all-gather the rows of V_high across ranks before computing dV_Vt (so
dV_Vt = local_dV @ gathered_V_high.T uses the global V_high rows) or instead
all-gather/reassemble the coefficient blocks dV_Vt into their correct global
column positions before applying the correction; update the code around
local_dV, local_V_high and the projection step (where dV_Vt is computed and
reduced, and local_dV.addmm_ is applied) to perform one of these gather-based
approaches rather than the plain all_reduce.
---
Nitpick comments:
In `@tests/test_osft.py`:
- Around line 1594-1595: Add the pytest unit marker to the new test suite:
ensure pytest is imported in the file (import pytest) if not already, and
annotate the TestVProjectionFactored class with `@pytest.mark.unit` above its
definition so the suite is recognized as a CPU-only unit test (reference the
TestVProjectionFactored class name and the pytest.mark.unit decorator).
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c4d0559a-ca2d-4e5e-9436-cac0dc8bc8bc
📒 Files selected for processing (4)
src/mini_trainer/osft_utils.pysrc/mini_trainer/setup_model_for_training.pysrc/mini_trainer/train.pytests/test_osft.py
🚧 Files skipped from review as they are similar to previous changes (1)
- src/mini_trainer/setup_model_for_training.py
| # Factored projection: dV -= (dV @ V_high^T) @ V_high | ||
| # Step 1: dV_Vt = dV @ V_high^T → (rank_low, rank_high) — small intermediate | ||
| dV_Vt = torch.mm(local_dV, local_V_high.transpose(0, 1)) | ||
| if dist.is_initialized() and dist.get_world_size() > 1: | ||
| dist.all_reduce(dV_Vt, op=dist.ReduceOp.SUM) | ||
|
|
||
| # Two local matmuls — no (M, M) intermediate | ||
| coeff = torch.mm(local_dV, V_high_full.transpose(0, 1)) # (k_low/P, k_high) | ||
| local_dV.addmm_(coeff, V_high_full, alpha=-1.0) # (k_low/P, M) | ||
| # Step 2: dV -= dV_Vt @ V_high — uses addmm_ to fuse subtraction | ||
| local_dV.addmm_(dV_Vt, local_V_high, alpha=-1.0) |
There was a problem hiding this comment.
The new V projection is incorrect for FSDP2 Shard(0) tensors.
local_dV @ local_V_high.T is not a partial sum in this layout; it produces coefficients for each rank's local V_high rows. all_reduce(SUM) therefore mixes unrelated coefficient blocks from different ranks instead of reconstructing the global (rank_low, rank_high) projection, so V_low.grad is wrong on real sharded DTensors. The new tests still pass because they only exercise replicated tensors, not the row-sharded case.
🧭 Fix direction
Restore a layout-aware gather for the V basis/coefficient blocks here. all_reduce is valid for the U path because the contracted dimension is sharded; for V, with the current Shard(0) layout, you need either:
- an all-gather of
V_highrows before computingdV @ V_high.T, or - an all-gather/reassembly of coefficient blocks into their correct global column positions before applying the correction.
Also applies to: 2159-2187
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@src/mini_trainer/osft_utils.py` around lines 607 - 614, The V-projection is
wrong for FSDP2 Shard(0) layouts because computing dV_Vt = local_dV @
local_V_high.T and then all_reduce mixes coefficient blocks from different
ranks; fix by restoring a layout-aware gather: either all-gather the rows of
V_high across ranks before computing dV_Vt (so dV_Vt = local_dV @
gathered_V_high.T uses the global V_high rows) or instead all-gather/reassemble
the coefficient blocks dV_Vt into their correct global column positions before
applying the correction; update the code around local_dV, local_V_high and the
projection step (where dV_Vt is computed and reduced, and local_dV.addmm_ is
applied) to perform one of these gather-based approaches rather than the plain
all_reduce.
Summary
addmm_call, flatten to 2D fortorch.mm(K, K)Gram matrix with factored(rank_low, rank_high)intermediate, fuse subtraction viaaddmm_torch_dtype→dtyperename, fix mixed-precision dtype validationBenchmark
Llama-3.1-8B-Instruct, 4x H100 80GB, bf16, OSFT
rank_ratio=0.5,batch_size=32:Dataset: 1,000 samples from UltraChat-200k, tokenized with Llama-3.1 chat template, median 1,118 tokens.
Details
1. Factorized linear (
_factorized_linear) — +11%Before: 4 separate matmuls + element-wise addition
After: flatten to 2D, 3
torch.mm+ 1addmm_2. Gradient projection (
project_gradient_to_orthogonal_space) — +4%Before: materializes
(K, K)Gram matrix (e.g. 4096×4096 for Llama)After: factored form with
(rank_low, rank_high)intermediate3. Batched V projection all-reduces — +2.6%
Before: 224 separate
dist.all_reduce()calls for V projection coefficientsAfter: single batched
dist.all_reduce()of concatenated coefficients (same pattern as existing U projection batching from PR #72)4. Transformers v5 compatibility
torch_dtypeanddtypekwargs tofrom_pretrainedtorch_dtype→dtype)Test plan
🤖 Generated with Claude Code
Summary by CodeRabbit
Improvements
Tests