Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/mini_trainer/osft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,6 +913,7 @@ def _load_model_memory_efficient(
model_args: tuple,
base_kwargs: dict,
osft_class_kwargs: dict,
train_dtype: torch.dtype | None = None,
):
"""
Memory-efficient loading for OSFT models to avoid CUDA/CPU OOM.
Expand All @@ -927,8 +928,10 @@ def _load_model_memory_efficient(
pretrained_model_name_or_path: Model path or name
model_args: Positional arguments for model loading
base_kwargs: Base model kwargs (already filtered)
init_cfg: OSFT configuration
osft_class_kwargs: OSFT class-specific parameters
train_dtype: Training dtype for model parameters. When provided, the
model is loaded in this dtype instead of relying on torch_dtype
from base_kwargs.

Returns:
Loaded OSFT model
Expand All @@ -952,9 +955,9 @@ def _load_model_memory_efficient(
# Remove additional OSFT parameters before calling base model's from_pretrained
final_base_kwargs = _filter_osft_parameters(base_kwargs, OSFT_BASE_MODEL_FILTERED_PARAMS)

# Force CPU loading via default behavior and match the train_dtype for FSDP2
# Need to get train_dtype from base_kwargs or default to float32
load_dtype = base_kwargs.get("torch_dtype")
# Use the explicit train_dtype when provided, otherwise fall back to
# torch_dtype from base_kwargs for backward compatibility.
load_dtype = train_dtype if train_dtype is not None else base_kwargs.get("torch_dtype")
if load_dtype is None:
raise ValueError("error: model does not have a `torch_dtype` setting, please report this to the developers")
final_base_kwargs["torch_dtype"] = load_dtype
Expand Down Expand Up @@ -1299,6 +1302,7 @@ def from_pretrained(
log_rank_0("\033[33m!!!! Calling from_pretrained !!!!\033[0m")

initialize_osft = kwargs.pop("initialize_osft", False)
train_dtype = kwargs.pop("train_dtype", None)

# validation
if fsdp2_lazy_init:
Expand Down Expand Up @@ -1334,6 +1338,7 @@ def from_pretrained(
model_args,
base_kwargs,
osft_class_kwargs,
train_dtype=train_dtype,
)
else:
# standard non-distributed loading
Expand Down
4 changes: 4 additions & 0 deletions src/mini_trainer/setup_model_for_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,7 @@ def setup_osft_model_distributed(
osft_target_patterns=None,
osft_upcast_dtype=torch.float32,
osft_output_dtype=None,
train_dtype: torch.dtype = torch.float32,
):
"""
Initialize an OSFT model for distributed training with memory-efficient loading.
Expand All @@ -799,6 +800,7 @@ def setup_osft_model_distributed(
osft_target_patterns: Patterns for selecting OSFT target parameters
osft_upcast_dtype: Dtype for OSFT computations
osft_output_dtype: Dtype for OSFT outputs
train_dtype: Training dtype for model parameters

Returns:
OSFT model ready for FSDP2 wrapping
Expand All @@ -821,6 +823,7 @@ def setup_osft_model_distributed(
**base_model_args,
"initialize_osft": True,
"fsdp2_lazy_init": True,
"train_dtype": train_dtype,
**osft_kwargs,
}

Expand Down Expand Up @@ -1186,6 +1189,7 @@ def load_osft_model():
osft_target_patterns=osft_target_patterns,
osft_upcast_dtype=osft_upcast_dtype,
osft_output_dtype=effective_osft_output_dtype,
train_dtype=train_dtype,
)
else:
# non-distributed path: direct OSFT model creation
Expand Down
97 changes: 97 additions & 0 deletions tests/test_osft.py
Original file line number Diff line number Diff line change
Expand Up @@ -2203,6 +2203,103 @@ def _align(model):
assert align_mock.call_count == 1
assert loaded_models and loaded_models[0].aligned is True

def test_memory_efficient_loading_uses_train_dtype(self, monkeypatch):
    """An explicit ``train_dtype`` must win over ``torch_dtype`` in ``base_kwargs``.

    The base class's ``from_pretrained`` is replaced by a stub that records the
    keyword arguments it receives, so the dtype forwarded by the loader can be
    inspected without loading a real checkpoint.
    """
    seen_kwargs = {}

    class DummyLoadedModel(nn.Module):
        # Stand-in for the object returned by the stubbed ``from_pretrained``.
        def __init__(self):
            super().__init__()
            fake_config = MagicMock()
            fake_config.vocab_size = 10
            self.config = fake_config

        def state_dict(self):
            return {"weight": torch.zeros(1)}

        def named_buffers(self):
            return [("buffer", torch.zeros(1))]

    class DummyBase(nn.Module):
        def __init__(self):
            super().__init__()

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            # Record every keyword the loader forwards to the base model.
            for key, value in kwargs.items():
                seen_kwargs[key] = value
            return DummyLoadedModel()

    class DummyOSFT(DummyBase):
        def __init__(self, config, **kwargs):
            super().__init__()
            self._lazy_init_pending = True
            self.config = config

    # Replace the torch.distributed entry points with single-process stubs
    # that report rank 0.
    dist_stubs = {
        "is_available": lambda: True,
        "is_initialized": lambda: True,
        "get_rank": lambda: 0,
        "barrier": lambda: None,
        "broadcast_object_list": lambda *_, **__: None,
    }
    for attr_name, stub in dist_stubs.items():
        monkeypatch.setattr(osft_module.dist, attr_name, stub)
    # Report CUDA as unavailable to the loader.
    monkeypatch.setattr(osft_module.torch.cuda, "is_available", lambda: False)

    requested_dtype = torch.bfloat16
    _load_model_memory_efficient(
        actual_osft_cls=DummyOSFT,
        pretrained_model_name_or_path="dummy",
        model_args=(),
        base_kwargs={"torch_dtype": torch.float32},
        osft_class_kwargs={},
        train_dtype=requested_dtype,
    )

    # float32 was in base_kwargs, but the explicit train_dtype must be used.
    assert seen_kwargs["torch_dtype"] == requested_dtype

def test_memory_efficient_loading_falls_back_to_base_kwargs(self, monkeypatch):
    """With no ``train_dtype`` given, ``torch_dtype`` from ``base_kwargs`` is used.

    The base class's ``from_pretrained`` is replaced by a stub that records the
    keyword arguments it receives, so the dtype forwarded by the loader can be
    inspected without loading a real checkpoint.
    """
    seen_kwargs = {}

    class DummyLoadedModel(nn.Module):
        # Stand-in for the object returned by the stubbed ``from_pretrained``.
        def __init__(self):
            super().__init__()
            fake_config = MagicMock()
            fake_config.vocab_size = 10
            self.config = fake_config

        def state_dict(self):
            return {"weight": torch.zeros(1)}

        def named_buffers(self):
            return [("buffer", torch.zeros(1))]

    class DummyBase(nn.Module):
        def __init__(self):
            super().__init__()

        @classmethod
        def from_pretrained(cls, *args, **kwargs):
            # Record every keyword the loader forwards to the base model.
            for key, value in kwargs.items():
                seen_kwargs[key] = value
            return DummyLoadedModel()

    class DummyOSFT(DummyBase):
        def __init__(self, config, **kwargs):
            super().__init__()
            self._lazy_init_pending = True
            self.config = config

    # Replace the torch.distributed entry points with single-process stubs
    # that report rank 0.
    dist_stubs = {
        "is_available": lambda: True,
        "is_initialized": lambda: True,
        "get_rank": lambda: 0,
        "barrier": lambda: None,
        "broadcast_object_list": lambda *_, **__: None,
    }
    for attr_name, stub in dist_stubs.items():
        monkeypatch.setattr(osft_module.dist, attr_name, stub)
    # Report CUDA as unavailable to the loader.
    monkeypatch.setattr(osft_module.torch.cuda, "is_available", lambda: False)

    fallback_dtype = torch.float16
    _load_model_memory_efficient(
        actual_osft_cls=DummyOSFT,
        pretrained_model_name_or_path="dummy",
        model_args=(),
        base_kwargs={"torch_dtype": fallback_dtype},
        osft_class_kwargs={},
    )

    # No train_dtype was passed, so the base_kwargs dtype must be forwarded.
    assert seen_kwargs["torch_dtype"] == fallback_dtype


class TestPostStepParameterProjection:
"""Test post-step parameter re-projection to fix AdamW subspace leak.
Expand Down
Loading