Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
186 changes: 73 additions & 113 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import gc
import math
import os
import types
import typing as t
Expand Down Expand Up @@ -553,20 +552,17 @@ def project_gradient_to_orthogonal_space(
U_high^T @ dU contracts over the sharded dim → partial sum → all-reduce
of a small (k_high, k_low) matrix.

V_high is dim-0 sharded along k_high (the small dimension), so the
factored form requires an all-gather of V_high to get the full
(k_high, M) tensor. This is M/k_high fewer bytes than the Gram
matrix all-reduce (k_high × M vs M × M) — 2x for square weights,
7x for down_proj where k_high = min(N, M) × (1 - URR).
V projection uses the factored form dV -= (dV @ V_high^T) @ V_high
with a small (rank_low, rank_high) intermediate. When FSDP2 shards
V_high on dim-0, `dV @ V_high^T` produces partial sums that are
correctly aggregated via all-reduce, then multiplied by local V_high
rows to complete the projection.

Args:
svd_dict: Dictionary containing the SVD decomposition components.
skip_u: If True, skip U projection (caller handles it externally,
e.g. via batched all-reduce in distributed mode).
cache_holder: Optional module on which to cache the all-gathered
V_high. V_high is frozen, so the cache is exact. When provided
and OSFT_CACHE_V is enabled, V_high is all-gathered once and
reused on subsequent steps.
cache_holder: Unused, kept for backwards compatibility.

TODO(osilkin): Add mixed-precision gradients here
"""
Expand Down Expand Up @@ -599,69 +595,23 @@ def project_gradient_to_orthogonal_space(
else:
dU.copy_(local_dU)

# Project V_low gradients: dV -= (dV @ V_high^T) @ V_high
# All-gather V_high from FSDP2 shards (or use cache) — see docstring for cost analysis.
# Project V_low gradients to space orthogonal to row(V_high).
# V_high has shape (k, M) with orthonormal rows (from SVD).
# Factored form: dV -= (dV @ V_high^T) @ V_high
# This avoids materializing the (M, M) Gram matrix G = V_high^T @ V_high.
if svd_dict["V_low"].grad is not None:
dV = svd_dict["V_low"].grad
local_V_high = getattr(V_high, "to_local", lambda: V_high)()
local_dV = getattr(dV, "to_local", lambda: dV)()

# V_high is frozen — reuse cached all-gathered tensor when available.
can_cache = OSFT_CACHE_V and cache_holder is not None
cached = getattr(cache_holder, "_osft_v_high_full", None) if can_cache else None

if cached is not None:
V_high_full = cached
else:
if dist.is_initialized() and dist.get_world_size() > 1:
full_k_high = svd_dict["rank_high"]
if local_V_high.shape[0] < full_k_high:
# FSDP2-sharded: all-gather V_high from all ranks.
world_size = dist.get_world_size()
remainder = full_k_high % world_size
if remainder == 0:
# Even split — direct gather, no padding needed.
V_high_full = torch.empty(
full_k_high,
local_V_high.shape[1],
dtype=local_V_high.dtype,
device=local_V_high.device,
)
dist.all_gather_into_tensor(V_high_full, local_V_high)
else:
# Uneven split — DTensor Shard(0) uses torch.chunk
# semantics: only the last shard is short, so padding
# lands at the tail of the gathered buffer. Pad to
# ceil rows for all_gather, then slice off padding.
rows_per_rank = math.ceil(full_k_high / world_size)
padded = torch.zeros(
rows_per_rank,
local_V_high.shape[1],
dtype=local_V_high.dtype,
device=local_V_high.device,
)
padded[: local_V_high.shape[0]].copy_(local_V_high)
gathered = torch.empty(
rows_per_rank * world_size,
local_V_high.shape[1],
dtype=local_V_high.dtype,
device=local_V_high.device,
)
dist.all_gather_into_tensor(gathered, padded)
V_high_full = gathered[:full_k_high]
else:
# to_local() returned the full tensor (not FSDP-sharded)
V_high_full = local_V_high
else:
V_high_full = local_V_high
if can_cache:
# .detach() ensures plain Tensor, not nn.Parameter — avoids
# nn.Module.__setattr__ registering it into state_dict.
cache_holder._osft_v_high_full = V_high_full.detach()
# Factored projection: dV -= (dV @ V_high^T) @ V_high
# Step 1: dV_Vt = dV @ V_high^T → (rank_low, rank_high) — small intermediate
dV_Vt = torch.mm(local_dV, local_V_high.transpose(0, 1))
if dist.is_initialized() and dist.get_world_size() > 1:
dist.all_reduce(dV_Vt, op=dist.ReduceOp.SUM)

# Two local matmuls — no (M, M) intermediate
coeff = torch.mm(local_dV, V_high_full.transpose(0, 1)) # (k_low/P, k_high)
local_dV.addmm_(coeff, V_high_full, alpha=-1.0) # (k_low/P, M)
# Step 2: dV -= dV_Vt @ V_high — uses addmm_ to fuse subtraction
local_dV.addmm_(dV_Vt, local_V_high, alpha=-1.0)
Comment on lines +607 to +614

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

The new V projection is incorrect for FSDP2 Shard(0) tensors.

local_dV @ local_V_high.T is not a partial sum in this layout; it produces coefficients for each rank's local V_high rows. all_reduce(SUM) therefore mixes unrelated coefficient blocks from different ranks instead of reconstructing the global (rank_low, rank_high) projection, so V_low.grad is wrong on real sharded DTensors. The new tests still pass because they only exercise replicated tensors, not the row-sharded case.

🧭 Fix direction

Restore a layout-aware gather for the V basis/coefficient blocks here. all_reduce is valid for the U path because the contracted dimension is sharded; for V, with the current Shard(0) layout, you need either:

  • an all-gather of V_high rows before computing dV @ V_high.T, or
  • an all-gather/reassembly of coefficient blocks into their correct global column positions before applying the correction.

Also applies to: 2159-2187

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@src/mini_trainer/osft_utils.py` around lines 607 - 614, The V-projection is
wrong for FSDP2 Shard(0) layouts because computing dV_Vt = local_dV @
local_V_high.T and then all_reduce mixes coefficient blocks from different
ranks; fix by restoring a layout-aware gather: either all-gather the rows of
V_high across ranks before computing dV_Vt (so dV_Vt = local_dV @
gathered_V_high.T uses the global V_high rows) or instead all-gather/reassemble
the coefficient blocks dV_Vt into their correct global column positions before
applying the correction; update the code around local_dV, local_V_high and the
projection step (where dV_Vt is computed and reduced, and local_dV.addmm_ is
applied) to perform one of these gather-based approaches rather than the plain
all_reduce.


if hasattr(dV, "_local_tensor"):
dV._local_tensor.copy_(local_dV)
Expand Down Expand Up @@ -846,6 +796,7 @@ def _load_model_memory_efficient(
if load_dtype is None:
raise ValueError("error: model does not have a `torch_dtype` setting, please report this to the developers")
final_base_kwargs["torch_dtype"] = load_dtype
final_base_kwargs["dtype"] = load_dtype # transformers v5 uses dtype

# initialize params to instance the OSFT model
# global rank 0 process actually loads the model, and all other procs
Expand Down Expand Up @@ -2038,12 +1989,11 @@ def _reconstruct_weight(

def _factorized_linear(self, x, svd_dict, bias=None):
"""
Efficient factorized linear operation using SVD components.
Optimized factorized linear with addmm_ fusion.

Computes: x @ (U_high @ S_high @ V_high + U_low @ S_low @ V_low).T + bias
As: (x @ V_high.T) @ (S_high * U_high).T + (x @ V_low.T) @ (S_low * U_low).T
Using: 3 mm + 1 addmm_ + 2 mul_ = 6 kernel launches (fused where possible).
"""
# Extract components
U_high = svd_dict["U_high"]
S_high = svd_dict["S_high"]
V_high = svd_dict["V_high"]
Expand All @@ -2054,26 +2004,30 @@ def _factorized_linear(self, x, svd_dict, bias=None):
device = x.device
dtype = x.dtype

# Move to correct device (keep native dtype)
U_high = U_high.to(device=device)
S_high = S_high.to(device=device)
V_high = V_high.to(device=device)
U_low = U_low.to(device=device)
S_low = S_low.to(device=device)
V_low = V_low.to(device=device)

# High-rank path (frozen): x @ V_high.T -> (batch, seq, rank_high)
x_V_high = x @ V_high.transpose(0, 1)
result_high = (x_V_high * S_high) @ U_high.transpose(0, 1)
# Flatten x to 2D for efficient mm/addmm_
orig_shape = x.shape
x_2d = x.reshape(-1, orig_shape[-1])

# Low-rank path (trainable): x @ V_low.T -> (batch, seq, rank_low)
x_V_low = x @ V_low.transpose(0, 1)
result_low = (x_V_low * S_low) @ U_low.transpose(0, 1)
# High-rank path (frozen)
tmp_high = torch.mm(x_2d, V_high.transpose(0, 1))
tmp_high.mul_(S_high.unsqueeze(0))
result = torch.mm(tmp_high, U_high.transpose(0, 1))

# Combine both paths
result = result_high + result_low
# Low-rank path (trainable): fuse matmul + addition via addmm_
tmp_low = torch.mm(x_2d, V_low.transpose(0, 1))
tmp_low.mul_(S_low.unsqueeze(0))
result.addmm_(tmp_low, U_low.transpose(0, 1))

# Reshape back to original batch dims
result = result.reshape(*orig_shape[:-1], result.shape[-1])

# Add bias if present
if bias is not None:
result = result + bias.to(device=device, dtype=dtype)

Expand Down Expand Up @@ -2108,16 +2062,10 @@ def project_gradients(self):

This method should be called after backpropagation and before optimizer step.

In distributed mode, U projection coefficients are batched into a single
all-reduce instead of one per OSFT target. For Llama-8B with 224 targets
this reduces 224 U all-reduce kernel launches to 1, cutting latency from
collective launch overhead. V projections continue per-module.

When ``OSFT_CACHE_V=1`` is set, the all-gathered V_high tensor is
cached on each module after the first step. V_high is frozen, so
the cache is exact. This eliminates per-step V all-gather traffic.
Default is off because the cache is replicated on every FSDP2 rank
(~5.1 GB for Llama-8B, infeasible for 70B+).
In distributed mode, both U and V projection coefficients are batched
into a single all-reduce each, instead of one per OSFT target. For
Llama-8B with 224 targets this reduces 448 all-reduce kernel launches
to 2, cutting latency from collective launch overhead.
"""
is_distributed = dist.is_initialized() and dist.get_world_size() > 1

Expand Down Expand Up @@ -2202,29 +2150,41 @@ def project_gradients(self):
f"Batch split consumed {offset} elements but tensor has {batched.numel()}"
)

# V projections: per-module factored V all-gather.
# Reuse the shared function with skip_u=True to avoid code
# duplication — the V projection logic is identical to the
# non-batched path.
caches_populated_this_call = 0
for svd_dict, module in zip(svd_dicts, svd_modules):
had_cache = hasattr(module, "_osft_v_high_full")
project_gradient_to_orthogonal_space(svd_dict, skip_u=True, cache_holder=module)
if not had_cache and hasattr(module, "_osft_v_high_full"):
caches_populated_this_call += 1

if caches_populated_this_call > 0:
total_bytes = sum(
module._osft_v_high_full.nelement() * module._osft_v_high_full.element_size()
for module in self.modules()
if hasattr(module, "_osft_v_high_full")
)
log_rank_0(
f"Cached {caches_populated_this_call} V_high tensors "
f"({total_bytes / 1e9:.2f} GB). "
f"Subsequent steps skip V all-gathers. "
f"Set OSFT_CACHE_V=0 to disable."
)
# V projections: batched all-reduce (same pattern as U above).
# Factored form: dV -= (dV @ V_high^T) @ V_high
# Batch all (dV @ V_high^T) coefficients into a single all-reduce.
v_work = []
v_flat_parts = []

for svd_dict in svd_dicts:
if svd_dict["V_low"].grad is not None:
dV = svd_dict["V_low"].grad
V_high = svd_dict["V_high"]
local_V_high = getattr(V_high, "to_local", lambda x=V_high: x)()
local_dV = getattr(dV, "to_local", lambda x=dV: x)()

dV_Vt = torch.mm(local_dV, local_V_high.transpose(0, 1))
v_work.append((local_V_high, local_dV, dV, dV_Vt.shape))
v_flat_parts.append(dV_Vt.flatten())

if v_flat_parts:
batched_v = torch.cat(v_flat_parts)
dist.all_reduce(batched_v, op=dist.ReduceOp.SUM)

offset = 0
for local_V_high, local_dV, dV, coeff_shape in v_work:
numel = coeff_shape[0] * coeff_shape[1]
dV_Vt = batched_v[offset : offset + numel].reshape(coeff_shape)
offset += numel

local_dV.addmm_(dV_Vt, local_V_high, alpha=-1.0)

if hasattr(dV, "_local_tensor"):
dV._local_tensor.copy_(local_dV)
else:
dV.copy_(local_dV)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

assert offset == batched_v.numel()

def prepare_state_dict_for_save(self, state_dict):
"""Reconstruct dense weights into ``state_dict`` for saving with memory optimization."""
Expand Down
16 changes: 11 additions & 5 deletions src/mini_trainer/setup_model_for_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def get_model_save_dtype(
# correct mixed-precision settings. So to circumvent this, we load the
# original model's config separately
original_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
original_dtype = getattr(original_config, "torch_dtype", None)
original_dtype = getattr(original_config, "torch_dtype", None) or getattr(original_config, "dtype", None)

# HF models return a torch.dtype from this field, but docs mark it as an optional string
if original_dtype is not None and isinstance(original_dtype, str):
Expand Down Expand Up @@ -985,7 +985,8 @@ def setup_model(
) -> torch.nn.Module | OSFTModel:
base_model_args = {
"pretrained_model_name_or_path": model_name_or_path,
"torch_dtype": train_dtype, # Ensure models are loaded in the training dtype
"torch_dtype": train_dtype, # kept for internal OSFT code that reads this key
"dtype": train_dtype, # transformers v5 uses dtype for from_pretrained
"trust_remote_code": trust_remote_code,
}

Expand Down Expand Up @@ -1220,9 +1221,14 @@ def load_osft_model():
model = load_osft_model() if osft else load_standard_model()

# here we handle configuring the save_dtype
model.config.torch_dtype = get_model_save_dtype(save_dtype, model_name_or_path, trust_remote_code)
if not model.config.torch_dtype:
raise ValueError("error: model does not have a `torch_dtype` setting, cannot save model in this dtype")
_save_dtype = get_model_save_dtype(save_dtype, model_name_or_path, trust_remote_code)
if not _save_dtype:
raise ValueError("error: model does not have a `torch_dtype` setting, please report this to the developers")
# transformers v5 uses `dtype` instead of `torch_dtype`
if hasattr(model.config, "dtype"):
model.config.dtype = _save_dtype
else:
model.config.torch_dtype = _save_dtype

# Freeze GPT-OSS router parameters BEFORE FSDP2 setup to avoid uniformity issues
if is_gpt_oss:
Expand Down
17 changes: 12 additions & 5 deletions src/mini_trainer/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,17 @@ def validate_training_state(
expected_param_dtype: Expected dtype for model parameters and gradients
expected_optimizer_dtype: Expected dtype for optimizer state (usually float32 for numerical stability)
"""
# FSDP2 MixedPrecisionPolicy may store params in fp32; allow this as valid
allowed_param_dtypes = {expected_param_dtype, torch.float32}
for name, param in model.named_parameters():
if param.requires_grad and param.dtype != expected_param_dtype:
raise ValueError(f"Parameter {name} is not in {expected_param_dtype}, got {param.dtype}")
if param.grad is not None and param.grad.dtype != expected_param_dtype:
raise ValueError(f"Gradient {name} is not in {expected_param_dtype}, got {param.grad.dtype}")
if param.requires_grad and param.dtype not in allowed_param_dtypes:
raise ValueError(
f"Parameter {name} has unexpected dtype {param.dtype}, expected one of {allowed_param_dtypes}"
)
if param.grad is not None and param.grad.dtype not in allowed_param_dtypes:
raise ValueError(
f"Gradient {name} has unexpected dtype {param.grad.dtype}, expected one of {allowed_param_dtypes}"
)

# Check optimizer state tensors - only for trainable parameters
for p_obj, state in optimizer.state.items():
Expand Down Expand Up @@ -93,7 +99,8 @@ def take_gradient_step(model, optimizer, lr_scheduler, expected_dtype=torch.floa
model,
optimizer,
expected_param_dtype=expected_dtype,
expected_optimizer_dtype=expected_dtype,
# Optimizer states (exp_avg, exp_avg_sq) are always fp32 for numerical stability
expected_optimizer_dtype=torch.float32,
)
optimizer.zero_grad()
return grad_norm
Expand Down
Loading
Loading