
Commit 82a05bd

DTensor support for bfloat16 stochastic rounding
1 parent 1e473ed commit 82a05bd

File tree

1 file changed: +17 -2 lines


torchao/optim/quant_utils.py

Lines changed: 17 additions & 2 deletions
@@ -5,6 +5,13 @@
 # LICENSE file in the root directory of this source tree.
 import torch
 from torch import Tensor
+try:
+    from torch.distributed.tensor import DTensor
+except Exception:
+    try:
+        from torch.distributed._tensor import DTensor
+    except Exception:
+        DTensor = tuple()
 
 
 # https://github.com/TimDettmers/bitsandbytes/blob/dada530149212d64d4b69534716202659ef37ec8/bitsandbytes/functional.py#L339-L391
@@ -117,7 +124,7 @@ def dequant_with_qmap(codes: Tensor, qmap: Tensor, scale: Tensor):
     return out.view(codes.shape)
 
 
-def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
+def _fp32_to_bf16_sr(_x_f32: Tensor) -> Tensor:
     # For an FP32 number [a31, ..., a16, a15, ..., a0] to be converted to BF16
     # - Round towards zero: [a31, ..., a16, 0, ..., 0]
     # - Round away from zero: [a31, ..., a16+1, 0, ..., 0]
@@ -127,6 +134,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     # [a15, ..., a0] / 2^16, where the bit pattern [a15, ..., a0] is interpreted as uint16
     #
     # we have to use int32 since most arithmetic ops are not implemented for uint32/int16/uint16
+    is_dt = isinstance(_x_f32, DTensor)
+    x_f32 = _x_f32.to_local() if is_dt else _x_f32
+
     rand_16bit = torch.randint(
         0, 1 << 16, x_f32.shape, device=x_f32.device, dtype=torch.int32
     )
@@ -142,4 +152,9 @@ def _fp32_to_bf16_sr(x_f32: Tensor) -> Tensor:
     )
     # alternative, slightly faster
     # x_f32_bits = (x_f32_bits + rand_16bit) & 0xFFFF0000
-    return x_f32_bits.view(torch.float32).bfloat16()
+    x_bf16_trunc = x_f32_bits.view(torch.float32).bfloat16()
+
+    return DTensor.from_local(
+        x_bf16_trunc, _x_f32.device_mesh, _x_f32.placements,
+        run_check=False, shape=tuple(_x_f32.shape), stride=tuple(_x_f32.stride()),
+    ) if is_dt else x_bf16_trunc
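For reference: the import block at the top of the file tries the public torch.distributed.tensor path first, falls back to the older torch.distributed._tensor path, and finally binds DTensor to an empty tuple; since isinstance(x, ()) is always False, the DTensor branch is simply never taken on builds without DTensor. A minimal usage sketch of the new path (not part of the commit), assuming a single-rank gloo process group on CPU and a torch build that exposes the public DTensor API:

# Hypothetical sketch (not from this commit): round-trip a replicated DTensor
# through _fp32_to_bf16_sr on a single-rank CPU mesh.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

from torchao.optim.quant_utils import _fp32_to_bf16_sr

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)
mesh = init_device_mesh("cpu", (1,))

x = DTensor.from_local(torch.randn(64, 64), mesh, [Replicate()])
y = _fp32_to_bf16_sr(x)

# The commit unwraps to the local shard, rounds, then re-wraps with the
# input's device_mesh and placements, so the output is still a DTensor.
assert isinstance(y, DTensor) and y.dtype == torch.bfloat16
dist.destroy_process_group()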
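The rounding itself, per the comments in the diff: the probability of rounding away from zero equals the discarded low 16 bits interpreted as uint16, divided by 2^16, which makes the conversion unbiased in expectation. A quick sanity-check sketch (again, not part of the commit):

# Sketch: near 1.0, bf16 spacing is 2^-8, so 1 + 2^-10 sits a quarter of the
# way between the adjacent bf16 values 1.0 and 1.00390625.
import torch

from torchao.optim.quant_utils import _fp32_to_bf16_sr

torch.manual_seed(0)
x = torch.full((10_000,), 1.0 + 2**-10)  # exactly representable in fp32
sr_mean = _fp32_to_bf16_sr(x).float().mean().item()

# Round-to-nearest collapses every element to 1.0, biasing the mean;
# the stochastic-rounding mean stays close to the original fp32 value.
print(x.bfloat16().float()[0].item())  # 1.0
print(sr_mean)                         # ~1.0009765625
assert abs(sr_mean - (1.0 + 2**-10)) < 1e-3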
