 # LICENSE file in the root directory of this source tree.
 import torch
 from torch import Tensor
+from torch.distributed.tensor import DTensor


 # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +118,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
     return out.view(codes.shape)


-def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
+def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
     # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
     # - Round towards zero:   [a31, ..., a16, 0, ..., 0]
     # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +128,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     #   [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
+    is_dt = isinstance(_x_f32, DTensor)
+    x_f32 = _x_f32.to_local() if is_dt else _x_f32
+
     rand_16bit = torch.randint(
         0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
     )
@@ -142,4 +146,17 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     )
     # alternative, slightly faster
     # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
-    return x_f32_bits.view(torch.float32).bfloat16()
+    x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()
+
+    return (
+        DTensor.from_local(
+            x_bf16_trunc,
+            _x_f32.device_mesh,
+            _x_f32.placements,
+            run_check=False,
+            shape=tuple(_x_f32.shape),
+            stride=tuple(_x_f32.stride()),
+        )
+        if is_dt
+        else x_bf16_trunc
+    )
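For context on the comments in this diff: the rounding decision adds a uniform random 16-bit integer to the FP32 bit pattern and truncates the low 16 bits, so the result rounds away from zero with probability equal to the discarded fraction. Below is a minimal standalone sketch of that idea (not part of this commit; the halfway-point input and the -65536 mask are illustrative choices):

import torch

# 1.0 + 2**-8 lies exactly halfway between the adjacent BF16 values 1.0 and
# 1.0 + 2**-7, so the low 16 bits of its FP32 bit pattern are 0x8000.
x = torch.full((100_000,), 1.0 + 2**-8, dtype=torch.float32)

bits = x.view(torch.int32)
rand_16bit = torch.randint(0, 1 << 16, x.shape, dtype=torch.int32)

# Adding the random value carries into bit 16 with probability equal to the
# discarded fraction (0.5 here); -65536 is 0xFFFF0000 as a signed int32 and
# clears the low 16 bits.
sr_bits = (bits + rand_16bit) & -65536
x_sr = sr_bits.view(torch.float32).bfloat16()

print(x.bfloat16().float().mean())  # round-to-nearest-even: exactly 1.0
print(x_sr.float().mean())          # stochastic rounding: ~1.00390625 on average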
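A hypothetical usage sketch of the new DTensor path added by this commit (the import path for _fp32_to_bf16_sr, the gloo/CPU setup, and the shapes are assumptions): the input is unwrapped with to_local(), rounded shard-by-shard, then re-wrapped with DTensor.from_local() so the output keeps the input's device mesh and placements.

# Launch with e.g.: torchrun --nproc-per-node 2 sr_dtensor_demo.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor

# Assumed import path; adjust to wherever this quant_utils module lives in the repo.
from torchao.prototype.low_bit_optim.quant_utils import _fp32_to_bf16_sr

dist.init_process_group("gloo")
mesh = init_device_mesh("cpu", (dist.get_world_size(),))

x = distribute_tensor(torch.randn(1024, 256), mesh, [Shard(0)])  # row-sharded FP32 DTensor
y = _fp32_to_bf16_sr(x)

assert isinstance(y, DTensor)
assert y.dtype == torch.bfloat16
assert y.placements == x.placements and y.shape == x.shape

dist.destroy_process_group()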