 # LICENSE file in the root directory of this source tree.
 import torch
 from torch import Tensor
+try:
+    from torch.distributed.tensor import DTensor
+except Exception:
+    try:
+        from torch.distributed._tensor import DTensor
+    except Exception:
+        DTensor = tuple()


 # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +124,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
     return out.view(codes.shape)


-def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
+def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
     # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
     # - Round towards zero: [a31, ..., a16, 0, ..., 0]
     # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +134,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
+    is_dt = isinstance(_x_f32, DTensor)
+    x_f32 = _x_f32.to_local() if is_dt else _x_f32
+
     rand_16bit = torch.randint(
         0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
     )
@@ -142,4 +152,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     )
     # alternative, slightly faster
     # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
-    return x_f32_bits.view(torch.float32).bfloat16()
+    x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()
+
+    return DTensor.from_local(
+        x_bf16_trunc, _x_f32.device_mesh, _x_f32.placements,
+        run_check=False, shape=tuple(_x_f32.shape), stride=tuple(_x_f32.stride()),
+    ) if is_dt else x_bf16_trunc
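The patch unwraps a DTensor input to its local shard, applies the bitwise stochastic rounding on that plain tensor, and re-wraps the result with DTensor.from_local so the original device mesh, placements, global shape, and stride are preserved; the import falls back to an empty tuple so the isinstance check is simply a no-op on builds without DTensor.

A minimal sketch (not part of the diff) of how the stochastically rounded conversion could be checked on a plain tensor; the shapes, sample count, and the averaging check are assumptions, and exercising the DTensor branch would additionally require an initialized device mesh.

import torch

# Hypothetical sanity check: each element rounds up with probability equal to its
# truncated lower 16 bits / 2^16, so averaging many stochastically rounded copies
# should drift toward the original FP32 values.
x = torch.randn(4, 8)
samples = torch.stack([_fp32_to_bf16_sr(x).float() for _ in range(1000)])
print((samples.mean(dim=0) - x).abs().max())  # expected to shrink as the sample count grows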