Commit 01f0a97

DTensor support for bfloat16 stochastic rounding (#3266)
* DTensor support for bfloat16 stochastic rounding
* Mono import
* Test DTensor bf16 stochastic round parity
* ci fix
1 parent 226d7a4 commit 01f0a97

File tree

2 files changed: +58 −2 lines changed


test/test_low_bit_optim.py

Lines changed: 39 additions & 0 deletions

@@ -119,6 +119,45 @@ def test_bf16_stochastic_round(self, device, compile):
         # must cast BF16 tensor back to FP32 so that .mean() is accurate
         torch.testing.assert_close(x_rep_bf16.float().mean(1), x, atol=3e-5, rtol=3e-5)

+    @parametrize("device", _DEVICES)
+    @parametrize("compile", [False, True])
+    def test_bf16_stochastic_round_dtensor(self, device, compile):
+        pytest.importorskip("torch.distributed")
+        import torch.distributed as dist
+        from torch.distributed.device_mesh import init_device_mesh
+        from torch.distributed.tensor import DTensor, Replicate
+
+        created_pg = False
+        if dist.is_available() and not dist.is_initialized():
+            store = dist.TCPStore("127.0.0.1", 29500, 1, True)
+            dist.init_process_group(
+                backend="gloo",
+                store=store,
+                rank=0,
+                world_size=1,
+            )
+            created_pg = True
+
+        try:
+            torch.manual_seed(common_utils.SEED)
+            x = torch.rand(32, device=device) * 100
+            x_rep = x.view(-1, 1).repeat(1, 100_000)
+
+            func = torch.compile(
+                _fp32_to_bf16_sr, fullgraph=True, dynamic=False, disable=not compile
+            )
+            out_plain = func(x_rep)
+
+            mesh = init_device_mesh(device, (1,))
+            x_dt = DTensor.from_local(x_rep, mesh, [Replicate()], run_check=False)
+            out_dt = func(x_dt)
+
+            assert isinstance(out_dt, DTensor)
+            torch.testing.assert_close(out_dt.to_local(), out_plain)
+        finally:
+            if created_pg:
+                dist.destroy_process_group()
+

 class TestOptim(TestCase):
     @parametrize(
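
The existing test_bf16_stochastic_round above checks the statistical property that averaging many stochastically rounded BF16 copies recovers the FP32 input, while the new test only checks parity against the plain-tensor path. Below is a minimal sketch (not part of the commit) of running the same statistical check through the DTensor path; it assumes a single-rank gloo process group has already been initialized (as at the top of the new test) and calls the private helper torchao.optim.quant_utils._fp32_to_bf16_sr directly.

# Illustrative sketch only; assumes a single-rank process group is already initialized.
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

from torchao.optim.quant_utils import _fp32_to_bf16_sr

torch.manual_seed(0)
x = torch.rand(8) * 100                      # a few FP32 values in [0, 100)
x_rep = x.view(-1, 1).repeat(1, 100_000)     # many copies so the sample mean is stable

mesh = init_device_mesh("cpu", (1,))         # single-rank mesh, replicated placement
x_dt = DTensor.from_local(x_rep, mesh, [Replicate()], run_check=False)

out_dt = _fp32_to_bf16_sr(x_dt)              # bfloat16 DTensor with the same placements
mean = out_dt.to_local().float().mean(1)     # cast back to FP32 before averaging
torch.testing.assert_close(mean, x, atol=3e-5, rtol=3e-5)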

torchao/optim/quant_utils.py

Lines changed: 19 additions & 2 deletions

@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 from torch import Tensor
+from torch.distributed.tensor import DTensor


 # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +118,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
     return out.view(codes.shape)


-def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
+def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
     # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
     # - Round towards zero: [a31, ..., a16, 0, ..., 0]
     # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +128,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
+    is_dt = isinstance(_x_f32, DTensor)
+    x_f32 = _x_f32.to_local() if is_dt else _x_f32
+
     rand_16bit = torch.randint(
         0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
     )
@@ -142,4 +146,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     )
     # alternative, slightly faster
     # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
-    return x_f32_bits.view(torch.float32).bfloat16()
+    x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()
+
+    return (
+        DTensor.from_local(
+            x_bf16_trunc,
+            _x_f32.device_mesh,
+            _x_f32.placements,
+            run_check=False,
+            shape=tuple(_x_f32.shape),
+            stride=tuple(_x_f32.stride()),
+        )
+        if is_dt
+        else x_bf16_trunc
+    )
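
The rounding probability stated in the function's comments can be checked numerically. The sketch below (not from the commit) feeds the function a value that lies exactly halfway between two adjacent BF16 numbers, so the low 16 bits of its FP32 bit pattern are 0x8000 and each neighbour should be picked with probability about 0.5; it uses plain CPU tensors and the private helper torchao.optim.quant_utils._fp32_to_bf16_sr.

# Illustrative sketch only; plain tensors, no DTensor or process group needed.
import torch

from torchao.optim.quant_utils import _fp32_to_bf16_sr

torch.manual_seed(0)
x = torch.full((1_000_000,), 1.0 + 2**-8)    # halfway between BF16 neighbours 1.0 and 1.0 + 2**-7

x_bf16 = _fp32_to_bf16_sr(x)
frac_up = (x_bf16 == 1.0 + 2**-7).float().mean()

print(frac_up)                 # ~0.5: rounds away from zero with probability 0x8000 / 2**16
print(x_bf16.float().mean())   # ~1.00390625, i.e. the input is preserved in expectation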
