From a7afa47266c78e257787e82c70f10af7f60bce65 Mon Sep 17 00:00:00 2001 From: mini-trainer-dev Date: Tue, 2 Jun 2026 06:10:05 +0000 Subject: [PATCH] Pass train_dtype into _load_model_memory_efficient MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The OSFT memory-efficient loading path was extracting the training dtype indirectly from base_kwargs["torch_dtype"] instead of accepting it as an explicit parameter. This made it fragile and inconsistent with the SFT distributed loading path, which takes train_dtype directly. Thread train_dtype through: setup_model() → setup_osft_model_distributed() → from_pretrained() → _load_model_memory_efficient(), with backward-compatible fallback to torch_dtype from base_kwargs when not provided. Closes Red-Hat-AI-Innovation-Team/mini_trainer#34 Co-Authored-By: Claude Opus 4.6 (1M context) Co-authored-by: multica-agent --- src/mini_trainer/osft_utils.py | 13 ++- src/mini_trainer/setup_model_for_training.py | 4 + tests/test_osft.py | 97 ++++++++++++++++++++ 3 files changed, 110 insertions(+), 4 deletions(-) diff --git a/src/mini_trainer/osft_utils.py b/src/mini_trainer/osft_utils.py index c3106c3..2c40d0e 100644 --- a/src/mini_trainer/osft_utils.py +++ b/src/mini_trainer/osft_utils.py @@ -913,6 +913,7 @@ def _load_model_memory_efficient( model_args: tuple, base_kwargs: dict, osft_class_kwargs: dict, + train_dtype: torch.dtype | None = None, ): """ Memory-efficient loading for OSFT models to avoid CUDA/CPU OOM. @@ -927,8 +928,10 @@ def _load_model_memory_efficient( pretrained_model_name_or_path: Model path or name model_args: Positional arguments for model loading base_kwargs: Base model kwargs (already filtered) - init_cfg: OSFT configuration osft_class_kwargs: OSFT class-specific parameters + train_dtype: Training dtype for model parameters. When provided, the + model is loaded in this dtype instead of relying on torch_dtype + from base_kwargs. 
Returns: Loaded OSFT model @@ -952,9 +955,9 @@ def _load_model_memory_efficient( # Remove additional OSFT parameters before calling base model's from_pretrained final_base_kwargs = _filter_osft_parameters(base_kwargs, OSFT_BASE_MODEL_FILTERED_PARAMS) - # Force CPU loading via default behavior and match the train_dtype for FSDP2 - # Need to get train_dtype from base_kwargs or default to float32 - load_dtype = base_kwargs.get("torch_dtype") + # Use the explicit train_dtype when provided, otherwise fall back to + # torch_dtype from base_kwargs for backward compatibility. + load_dtype = train_dtype if train_dtype is not None else base_kwargs.get("torch_dtype") if load_dtype is None: raise ValueError("error: model does not have a `torch_dtype` setting, please report this to the developers") final_base_kwargs["torch_dtype"] = load_dtype @@ -1299,6 +1302,7 @@ def from_pretrained( log_rank_0("\033[33m!!!! Calling from_pretrained !!!!\033[0m") initialize_osft = kwargs.pop("initialize_osft", False) + train_dtype = kwargs.pop("train_dtype", None) # validation if fsdp2_lazy_init: @@ -1334,6 +1338,7 @@ def from_pretrained( model_args, base_kwargs, osft_class_kwargs, + train_dtype=train_dtype, ) else: # standard non-distributed loading diff --git a/src/mini_trainer/setup_model_for_training.py b/src/mini_trainer/setup_model_for_training.py index 526501b..3ad8064 100644 --- a/src/mini_trainer/setup_model_for_training.py +++ b/src/mini_trainer/setup_model_for_training.py @@ -779,6 +779,7 @@ def setup_osft_model_distributed( osft_target_patterns=None, osft_upcast_dtype=torch.float32, osft_output_dtype=None, + train_dtype: torch.dtype = torch.float32, ): """ Initialize an OSFT model for distributed training with memory-efficient loading. 
@@ -799,6 +800,7 @@ def setup_osft_model_distributed( osft_target_patterns: Patterns for selecting OSFT target parameters osft_upcast_dtype: Dtype for OSFT computations osft_output_dtype: Dtype for OSFT outputs + train_dtype: Training dtype for model parameters Returns: OSFT model ready for FSDP2 wrapping @@ -821,6 +823,7 @@ def setup_osft_model_distributed( **base_model_args, "initialize_osft": True, "fsdp2_lazy_init": True, + "train_dtype": train_dtype, **osft_kwargs, } @@ -1186,6 +1189,7 @@ def load_osft_model(): osft_target_patterns=osft_target_patterns, osft_upcast_dtype=osft_upcast_dtype, osft_output_dtype=effective_osft_output_dtype, + train_dtype=train_dtype, ) else: # non-distributed path: direct OSFT model creation diff --git a/tests/test_osft.py b/tests/test_osft.py index 7723d71..7622c49 100644 --- a/tests/test_osft.py +++ b/tests/test_osft.py @@ -2203,6 +2203,103 @@ def _align(model): assert align_mock.call_count == 1 assert loaded_models and loaded_models[0].aligned is True + def test_memory_efficient_loading_uses_train_dtype(self, monkeypatch): + """train_dtype parameter should override torch_dtype from base_kwargs.""" + captured_kwargs = {} + + class DummyLoadedModel(nn.Module): + def __init__(self): + super().__init__() + self.config = MagicMock() + self.config.vocab_size = 10 + + def state_dict(self): + return {"weight": torch.zeros(1)} + + def named_buffers(self): + return [("buffer", torch.zeros(1))] + + class DummyBase(nn.Module): + def __init__(self): + super().__init__() + + @classmethod + def from_pretrained(cls, *args, **kwargs): + captured_kwargs.update(kwargs) + return DummyLoadedModel() + + class DummyOSFT(DummyBase): + def __init__(self, config, **kwargs): + super().__init__() + self.config = config + self._lazy_init_pending = True + + monkeypatch.setattr(osft_module.dist, "is_available", lambda: True) + monkeypatch.setattr(osft_module.dist, "is_initialized", lambda: True) + monkeypatch.setattr(osft_module.dist, "get_rank", lambda: 0) + 
monkeypatch.setattr(osft_module.dist, "barrier", lambda: None) + monkeypatch.setattr(osft_module.dist, "broadcast_object_list", lambda *_, **__: None) + monkeypatch.setattr(osft_module.torch.cuda, "is_available", lambda: False) + + _load_model_memory_efficient( + actual_osft_cls=DummyOSFT, + pretrained_model_name_or_path="dummy", + model_args=tuple(), + base_kwargs={"torch_dtype": torch.float32}, + osft_class_kwargs={}, + train_dtype=torch.bfloat16, + ) + + assert captured_kwargs["torch_dtype"] == torch.bfloat16 + + def test_memory_efficient_loading_falls_back_to_base_kwargs(self, monkeypatch): + """Without train_dtype, should fall back to torch_dtype from base_kwargs.""" + captured_kwargs = {} + + class DummyLoadedModel(nn.Module): + def __init__(self): + super().__init__() + self.config = MagicMock() + self.config.vocab_size = 10 + + def state_dict(self): + return {"weight": torch.zeros(1)} + + def named_buffers(self): + return [("buffer", torch.zeros(1))] + + class DummyBase(nn.Module): + def __init__(self): + super().__init__() + + @classmethod + def from_pretrained(cls, *args, **kwargs): + captured_kwargs.update(kwargs) + return DummyLoadedModel() + + class DummyOSFT(DummyBase): + def __init__(self, config, **kwargs): + super().__init__() + self.config = config + self._lazy_init_pending = True + + monkeypatch.setattr(osft_module.dist, "is_available", lambda: True) + monkeypatch.setattr(osft_module.dist, "is_initialized", lambda: True) + monkeypatch.setattr(osft_module.dist, "get_rank", lambda: 0) + monkeypatch.setattr(osft_module.dist, "barrier", lambda: None) + monkeypatch.setattr(osft_module.dist, "broadcast_object_list", lambda *_, **__: None) + monkeypatch.setattr(osft_module.torch.cuda, "is_available", lambda: False) + + _load_model_memory_efficient( + actual_osft_cls=DummyOSFT, + pretrained_model_name_or_path="dummy", + model_args=tuple(), + base_kwargs={"torch_dtype": torch.float16}, + osft_class_kwargs={}, + ) + + assert captured_kwargs["torch_dtype"] 
== torch.float16 + class TestPostStepParameterProjection: """Test post-step parameter re-projection to fix AdamW subspace leak.