29 changes: 27 additions & 2 deletions test/quantization/test_quant_primitives.py
@@ -9,6 +9,12 @@
 import unittest
 
 import torch
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+    run_tests,
+)
 
 from torchao.quantization.granularity import PerRow
 from torchao.quantization.quant_primitives import (
@@ -21,6 +27,8 @@
     _fake_quantize_affine_cachemask,
     _maybe_expand_scale_to_tensor_shape,
     _quantize_affine_float8,
+    _Round,
+    _StochasticRound,
     choose_qparams_affine,
     dequantize_affine,
     quantize_affine,
@@ -193,7 +201,7 @@ def _groupwise_affine_dequantize_tensor_from_qparams(
     return w_dq
 
 
-class TestQuantPrimitives(unittest.TestCase):
+class TestQuantPrimitives(TestCase):
     SEED = 123
 
     def test_get_group_qparams_symmetric(self):
@@ -913,6 +921,23 @@ def test_float8_rowwise_scaling_3d_weight_axis_1(self):
         assert scale.shape == (B, 1, N)
         assert data.shape == (B, K, N)
 
+    @parametrize("round_fn", [_Round, _StochasticRound])
+    def test_round_functions(self, round_fn):
+        x = torch.tensor([1.3, 2.7, -1.6, -2.2], dtype=torch.float32)
+        x_samples = x.view(1, -1).repeat(10000, 1)
+        rounded_samples = round_fn.apply(x_samples)
+
+        assert rounded_samples.dtype == x.dtype
+        torch.testing.assert_close(
+            rounded_samples, rounded_samples.round(), atol=0, rtol=0
+        )
+
+        # Unbiased property only holds for stochastic rounding
+        if round_fn == _StochasticRound:
+            torch.testing.assert_close(rounded_samples.mean(0), x, atol=5e-2, rtol=5e-2)
+
+
+instantiate_parametrized_tests(TestQuantPrimitives)
 
 if __name__ == "__main__":
-    unittest.main()
+    run_tests()
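The mean check in the new test works because stochastic rounding maps x to floor(x) + 1 with probability frac(x) and to floor(x) otherwise, so E[round_sr(x)] = x. A minimal standalone sketch of that property, using illustrative names that are not part of the diff:

import torch

torch.manual_seed(0)
x = torch.tensor([1.3, 2.7, -1.6, -2.2])
samples = x.expand(10000, -1)

# stochastic rounding: round up with probability equal to the fractional part
floor_x = torch.floor(samples)
sr = floor_x + (torch.rand_like(samples) < (samples - floor_x)).float()

print(sr.mean(0))  # ~[1.3, 2.7, -1.6, -2.2]: the sample mean recovers x
print(x.round())   # [1., 3., -2., -2.]: plain rounding is fixed, so its mean cannot recover x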
3 changes: 2 additions & 1 deletion torchao/prototype/quantized_training/int8.py
@@ -10,6 +10,7 @@
 from torch.utils._python_dispatch import return_and_correct_aliasing
 
 from torchao.core.config import AOBaseConfig
+from torchao.quantization.quant_primitives import _StochasticRound
 from torchao.quantization.transform_module import (
     register_quantize_module_handler,
 )
@@ -44,7 +45,7 @@ def quantize_int8_rowwise(
     ) # slightly faster than divide directly
 
     if stochastic_rounding:
-        tensor = (tensor + torch.rand_like(tensor)).floor()
+        tensor = _StochasticRound.apply(tensor)
     else:
         tensor = tensor.round()
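This swap preserves the forward-pass distribution: for u ~ U[0, 1), floor(x + u) equals floor(x) + 1 exactly when u >= 1 - frac(x), which happens with probability frac(x), the same Bernoulli draw `_StochasticRound` makes via `rand < prob`. What the autograd.Function adds is an STE backward. A quick sketch of the equivalence, not part of the diff and with illustrative names:

import torch

torch.manual_seed(0)
x = torch.full((100_000,), 2.3)

old = (x + torch.rand_like(x)).floor()                             # previous inline formulation
new = x.floor() + (torch.rand_like(x) < (x - x.floor())).float()   # _StochasticRound.forward

print(old.mean().item(), new.mean().item())  # both ~2.3: identical rounding distribution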
16 changes: 16 additions & 0 deletions torchao/quantization/quant_primitives.py
@@ -221,6 +221,22 @@ def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
         return gy
 
 
+class _StochasticRound(torch.autograd.Function):
+    """
+    Implementation of stochastic rounding with backward STE (straight-through estimator).
+    """
+
+    @staticmethod
+    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
+        floor_x = torch.floor(x)
+        prob = x - floor_x  # fractional part as the probability of rounding up
+        return floor_x + (torch.rand_like(x) < prob).to(x.dtype)
+
+    @staticmethod
+    def backward(ctx, gy: torch.Tensor) -> torch.Tensor:
+        return gy
+
+
 class _RoundToFloat8(torch.autograd.Function):
     """
     Implementation of `tensor.to(float8_dtype)` with backward STE.
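A minimal sketch of the STE behavior the new autograd.Function provides, assuming only the class body shown in the diff; `_StochasticRoundSketch` is a local stand-in for illustration, not torchao API:

import torch

class _StochasticRoundSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        floor_x = torch.floor(x)
        # round up with probability equal to the fractional part
        return floor_x + (torch.rand_like(x) < (x - floor_x)).to(x.dtype)

    @staticmethod
    def backward(ctx, gy):
        # straight-through estimator: pass the incoming gradient through unchanged
        return gy

x = torch.tensor([0.25, 1.75], requires_grad=True)
y = _StochasticRoundSketch.apply(x)
y.sum().backward()
print(x.grad)  # tensor([1., 1.]); the step-like forward would otherwise give zero gradient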