Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 34 additions & 12 deletions src/art/unsloth/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,37 @@ def _canonicalize_upstream_metrics(metrics: dict[str, float]) -> dict[str, float
}


def _get_dtype_for_autocasting(model: torch.nn.Module) -> torch.dtype:
if os.environ.get("UNSLOTH_FORCE_FLOAT32") == "1":
return torch.float16

match os.environ.get("ACCELERATE_MIXED_PRECISION"):
case "fp16":
return torch.float16
case "bf16":
return torch.bfloat16
case None:
pass
case mixed_precision:
raise AssertionError(
f"Unsupported ACCELERATE_MIXED_PRECISION={mixed_precision!r}"
)

dtype_numels: dict[torch.dtype, int] = defaultdict(int)
for param in model.parameters():
if param.is_floating_point():
dtype_numels[param.dtype] += param.numel()

assert dtype_numels, "Expected model to have floating-point parameters"
model_dtype, _ = max(dtype_numels.items(), key=lambda item: item[1])
if model_dtype == torch.bfloat16:
return torch.bfloat16
if model_dtype in (torch.float16, torch.float32):
return torch.float16

raise AssertionError(f"Unsupported model dtype {model_dtype}")


async def train(
trainer: "GRPOTrainer",
results_queue: asyncio.Queue[dict[str, float]],
Expand All @@ -339,6 +370,9 @@ async def train(


def get_compute_loss_fn(trainer: "GRPOTrainer") -> Callable[..., torch.Tensor]:
assert isinstance(trainer.model, torch.nn.Module)
dtype_for_autocasting = _get_dtype_for_autocasting(trainer.model)

def compute_loss(
model: "PeftModel",
inputs: "TrainInputs",
Expand Down Expand Up @@ -379,18 +413,6 @@ def compute_loss(
for key, tensor in inputs.items()
} # ty:ignore[invalid-assignment]

accelerate_mixed_precision = os.environ.get("ACCELERATE_MIXED_PRECISION")
force_float32 = os.environ.get("UNSLOTH_FORCE_FLOAT32")

if (
accelerate_mixed_precision is None
or accelerate_mixed_precision == "fp16"
or force_float32 == "1"
):
dtype_for_autocasting = torch.float16
else:
dtype_for_autocasting = torch.bfloat16

batch_size, seq_len = inputs["tokens"].size()
attn_bias = calculate_attn_bias(
batch_size,
Expand Down
55 changes: 55 additions & 0 deletions tests/unit/test_unsloth_autocast_dtype.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import torch

from art.unsloth.train import _get_dtype_for_autocasting


class _TinyModel(torch.nn.Module):
def __init__(self, dtype_numels: list[tuple[torch.dtype, int]]) -> None:
super().__init__()
self.params = torch.nn.ParameterList(
[
torch.nn.Parameter(torch.empty(numel, dtype=dtype))
for dtype, numel in dtype_numels
]
)


def test_get_dtype_for_autocasting_infers_bfloat16_model_when_env_unset(
    monkeypatch,
) -> None:
    # With neither env override present, the dtype is inferred from params.
    for variable in ("ACCELERATE_MIXED_PRECISION", "UNSLOTH_FORCE_FLOAT32"):
        monkeypatch.delenv(variable, raising=False)

    # bf16 holds the majority of parameter elements (8 of 9).
    model = _TinyModel([(torch.bfloat16, 8), (torch.float32, 1)])

    assert _get_dtype_for_autocasting(model) == torch.bfloat16


def test_get_dtype_for_autocasting_keeps_fp16_default_for_fp32_model(
    monkeypatch,
) -> None:
    # With no env overrides, a pure-fp32 model keeps the fp16 default.
    for variable in ("ACCELERATE_MIXED_PRECISION", "UNSLOTH_FORCE_FLOAT32"):
        monkeypatch.delenv(variable, raising=False)

    fp32_model = _TinyModel([(torch.float32, 8)])

    assert _get_dtype_for_autocasting(fp32_model) == torch.float16


def test_get_dtype_for_autocasting_honors_explicit_fp16(monkeypatch) -> None:
    # An explicit fp16 setting wins even though the model itself is bf16.
    monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "fp16")
    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)

    bf16_model = _TinyModel([(torch.bfloat16, 8)])

    assert _get_dtype_for_autocasting(bf16_model) == torch.float16


def test_get_dtype_for_autocasting_honors_explicit_bfloat16(monkeypatch) -> None:
    # An explicit bf16 setting wins even though the model itself is fp16.
    monkeypatch.setenv("ACCELERATE_MIXED_PRECISION", "bf16")
    monkeypatch.delenv("UNSLOTH_FORCE_FLOAT32", raising=False)

    fp16_model = _TinyModel([(torch.float16, 8)])

    assert _get_dtype_for_autocasting(fp16_model) == torch.bfloat16
Loading