diff --git a/src/art/unsloth/train.py b/src/art/unsloth/train.py
index 2d23a9d8..cea30e3a 100644
--- a/src/art/unsloth/train.py
+++ b/src/art/unsloth/train.py
@@ -313,6 +313,50 @@ def _canonicalize_upstream_metrics(metrics: dict[str, float]) -> dict[str, float
 }
 
 
+def _get_dtype_for_autocasting(model: torch.nn.Module) -> torch.dtype:
+    """Choose the dtype to use for autocasting.
+
+    Precedence: ``UNSLOTH_FORCE_FLOAT32=1`` forces float16 (preserving the
+    behavior this helper replaced), then an explicit
+    ``ACCELERATE_MIXED_PRECISION`` setting ("fp16"/"bf16"), then the dtype
+    held by the majority of the model's floating-point parameters
+    (bfloat16 weights autocast to bfloat16; float16/float32 weights
+    autocast to float16).
+    """
+    if os.environ.get("UNSLOTH_FORCE_FLOAT32") == "1":
+        return torch.float16
+
+    match os.environ.get("ACCELERATE_MIXED_PRECISION"):
+        case "fp16":
+            return torch.float16
+        case "bf16":
+            return torch.bfloat16
+        case None | "no":
+            # Unset, or mixed precision explicitly disabled by accelerate
+            # ("no" is a standard value): infer from the weights below.
+            pass
+        case mixed_precision:
+            raise AssertionError(
+                f"Unsupported ACCELERATE_MIXED_PRECISION={mixed_precision!r}"
+            )
+
+    # Tally parameter counts per floating-point dtype; the dominant dtype
+    # decides the autocast target.
+    dtype_numels: dict[torch.dtype, int] = defaultdict(int)
+    for param in model.parameters():
+        if param.is_floating_point():
+            dtype_numels[param.dtype] += param.numel()
+
+    assert dtype_numels, "Expected model to have floating-point parameters"
+    model_dtype, _ = max(dtype_numels.items(), key=lambda item: item[1])
+    if model_dtype == torch.bfloat16:
+        return torch.bfloat16
+    if model_dtype in (torch.float16, torch.float32):
+        return torch.float16
+
+    raise AssertionError(f"Unsupported model dtype {model_dtype}")
+
+
 async def train(
     trainer: "GRPOTrainer",
     results_queue: asyncio.Queue[dict[str, float]],
@@ -339,6 +383,10 @@
 
 
 def get_compute_loss_fn(trainer: "GRPOTrainer") -> Callable[..., torch.Tensor]:
+    # Resolve the autocast dtype once, when the loss function is created.
+    assert isinstance(trainer.model, torch.nn.Module)
+    dtype_for_autocasting = _get_dtype_for_autocasting(trainer.model)
+
     def compute_loss(
         model: "PeftModel",
         inputs: "TrainInputs",
@@ -379,18 +427,6 @@ def compute_loss(
             for key, tensor in inputs.items()
         }  # ty:ignore[invalid-assignment]
 
-        accelerate_mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION")
-        force_float32 = os.environ.get("UNSLOTH_FORCE_FLOAT32")
-
-        if (
-            accelerate_mixed_precision is None
-            or accelerate_mixed_precision == "fp16"
-            or force_float32 == "1"
-        ):
-            dtype_for_autocasting = torch.float16
-        else:
-            dtype_for_autocasting = torch.bfloat16
-
         batch_size, seq_len = inputs["tokens"].size()
         attn_bias = calculate_attn_bias(
             batch_size,
diff --git a/tests/unit/test_unsloth_autocast_dtype.py b/tests/unit/test_unsloth_autocast_dtype.py
new file mode 100644
index 00000000..5438077f
--- /dev/null
+++ b/tests/unit/test_unsloth_autocast_dtype.py
@@ -0,0 +1,63 @@
+import torch
+
+from art.unsloth.train import _get_dtype_for_autocasting
+
+
+class _TinyModel(torch.nn.Module):
+    def __init__(self, dtype_numels: list[tuple[torch.dtype, int]]) -> None:
+        super().__init__()
+        self.params = torch.nn.ParameterList(
+            [
+                torch.nn.Parameter(torch.empty(numel, dtype=dtype))
+                for dtype, numel in dtype_numels
+            ]
+        )
+
+
+def test_get_dtype_for_autocasting_infers_bfloat16_model_when_env_unset(
+    monkeypatch,
+) -> None:
+    monkeypatch.delenv("ACCELERATE_MIXED_PRECISION", raising=False)
+    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
+    model = _TinyModel(
+        [
+            (torch.bfloat16, 8),
+            (torch.float32, 1),
+        ]
+    )
+
+    assert _get_dtype_for_autocasting(model) == torch.bfloat16
+
+
+def test_get_dtype_for_autocasting_keeps_fp16_default_for_fp32_model(
+    monkeypatch,
+) -> None:
+    monkeypatch.delenv("ACCELERATE_MIXED_PRECISION", raising=False)
+    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
+    model = _TinyModel([(torch.float32, 8)])
+
+    assert _get_dtype_for_autocasting(model) == torch.float16
+
+
+def test_get_dtype_for_autocasting_honors_explicit_fp16(monkeypatch) -> None:
+    monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "fp16")
+    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
+    model = _TinyModel([(torch.bfloat16, 8)])
+
+    assert _get_dtype_for_autocasting(model) == torch.float16
+
+
+def test_get_dtype_for_autocasting_honors_explicit_bfloat16(monkeypatch) -> None:
+    monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "bf16")
+    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
+    model = _TinyModel([(torch.float16, 8)])
+
+    assert _get_dtype_for_autocasting(model) == torch.bfloat16
+
+
+def test_get_dtype_for_autocasting_treats_no_as_unset(monkeypatch) -> None:
+    monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "no")
+    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)
+    model = _TinyModel([(torch.bfloat16, 8)])
+
+    assert _get_dtype_for_autocasting(model) == torch.bfloat16