Skip to content

Commit ec99f95

Browse files
gau-nernstHanxian97
authored andcommitted
Improve FSDP support for low-bit optimizers (#538)
1 parent 6718cf5 commit ec99f95

File tree

6 files changed

+153
-29
lines changed

6 files changed

+153
-29
lines changed

test/prototype/test_low_bit_optim.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls):
226226
base_optim.step()
227227
self.assertEqual(fsdp_loss, base_loss)
228228

229+
base_param = base_optim.param_groups[0]["params"][0]
230+
base_exp_avg = base_optim.state[base_param]["exp_avg"]
231+
232+
fsdp_param = fsdp_optim.param_groups[0]["params"][0]
233+
fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"]
234+
full_fsdp_exp_avg = fsdp_exp_avg.full_tensor()
235+
236+
self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
237+
229238

230239
instantiate_parametrized_tests(TestQuantize)
231240
instantiate_parametrized_tests(TestOptim)

torchao/prototype/low_bit_optim/adam.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
3939
def _new_buffer(self, p: Tensor, signed: bool):
4040
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
4141
if isinstance(p, DTensor):
42-
out = torch.empty_like(p)
43-
out._local_tensor = self._subclass_zeros(
44-
out._local_tensor,
45-
signed,
46-
self.block_size,
42+
out = DTensor.from_local(
43+
local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
44+
device_mesh=p.device_mesh,
45+
placements=p.placements,
46+
run_check=False,
4747
)
4848
else:
4949
out = self._subclass_zeros(p, signed, self.block_size)

torchao/prototype/low_bit_optim/adamw.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
3939
def _new_buffer(self, p: Tensor, signed: bool):
4040
if p.numel() >= 4096 and p.numel() % self.block_size == 0:
4141
if isinstance(p, DTensor):
42-
out = torch.empty_like(p)
43-
out._local_tensor = self._subclass_zeros(
44-
out._local_tensor,
45-
signed,
46-
self.block_size,
42+
out = DTensor.from_local(
43+
local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
44+
device_mesh=p.device_mesh,
45+
placements=p.placements,
46+
run_check=False,
4747
)
4848
else:
4949
out = self._subclass_zeros(p, signed, self.block_size)

torchao/prototype/low_bit_optim/subclass_4bit.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88

99

1010
aten = torch.ops.aten
11-
11+
c10d_functional = torch.ops.c10d_functional
12+
_c10d_functional = torch.ops._c10d_functional
1213

1314
# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
1415
# NOTE: power-1 is linear
@@ -31,17 +32,29 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape
3132
)
3233

3334
def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, shape):
35+
"""Create quantized 4-bit optimizer state as proposed in https://arxiv.org/abs/2309.01507
36+
37+
Args
38+
codes: quantized and packed 4-bit data stored as uint8.
39+
scale: scale data for block-wise quantization.
40+
qmap: lookup table that maps between quantized value (code) and float value.
41+
signed: whether the tensor is signed or unsigned.
42+
shape: shape of original float tensor.
43+
44+
NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
45+
Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
46+
Given `codes` and `scale`, `block_size` is calculated as `codes.numel() * 2 // scale.numel()`.
47+
The extra `* 2` is because `codes` is 4-bit data packed in 8-bit storage.
48+
"""
3449
assert codes.dtype is torch.uint8
3550
assert codes.ndim == 1 # flattened buffer
51+
assert scale.ndim == 1
3652
self.codes = codes
3753
self.scale = scale
3854
self.qmap = qmap
3955
self.signed = signed
4056
self._shape = shape
41-
42-
@property
43-
def block_size(self):
44-
return self.codes.numel() * 2 // self.scale.numel()
57+
self.block_size = codes.numel() * 2 // scale.numel()
4558

4659
def __tensor_flatten__(self):
4760
return self.tensor_attrs, [self.signed, self._shape]
@@ -113,9 +126,37 @@ def _(func, *args, **kwargs):
113126
return func(*args, **kwargs)
114127

115128

129+
# this is needed for DTensor.from_local() and for flattening tensor
116130
@OptimState4bit.implements(aten.view.default)
117131
def _(func, *args, **kwargs):
118132
x, shape = args
119-
if len(shape) > 1 or shape[0] != -1:
120-
raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
121-
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
133+
134+
if tuple(x.shape) == tuple(shape):
135+
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)
136+
137+
if len(shape) == 1 and shape[0] == -1:
138+
return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
139+
140+
raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
141+
142+
143+
# this is needed for DTensor.full_tensor()
144+
@OptimState4bit.implements([
145+
c10d_functional.all_gather_into_tensor.default,
146+
_c10d_functional.all_gather_into_tensor.default,
147+
c10d_functional.wait_tensor.default,
148+
_c10d_functional.wait_tensor.default,
149+
])
150+
def _(func, *args, **kwargs):
151+
x = args[0]
152+
if not isinstance(x, OptimState4bit):
153+
raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
154+
155+
codes = func(x.codes, *args[1:], **kwargs)
156+
scale = func(x.scale, *args[1:], **kwargs)
157+
158+
# adjust the first dim
159+
shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]
160+
161+
# assume tensors from all ranks have the same signedness
162+
return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)

torchao/prototype/low_bit_optim/subclass_8bit.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@
66

77

88
aten = torch.ops.aten
9+
c10d_functional = torch.ops.c10d_functional
10+
_c10d_functional = torch.ops._c10d_functional
911

1012
QMAP_SIGNED = create_dynamic_map(signed=True)
1113
QMAP_UNSIGNED = create_dynamic_map(signed=False)
1214

1315

14-
# dynamic tree quantization
15-
# https://arxiv.org/pdf/1511.04561
16-
# https://arxiv.org/abs/2110.02861
1716
class OptimState8bit(Tensor):
1817
implements = classmethod(_implements)
1918
tensor_attrs = ["codes", "scale", "qmap"]
@@ -28,15 +27,25 @@ def __new__(cls, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
2827
)
2928

3029
def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
30+
"""Create quantized 8-bit optimizer state as proposed in https://arxiv.org/abs/2110.02861
31+
32+
Args
33+
codes: quantized 8-bit data stored as uint8. Has the same shape as the original float tensor.
34+
scale: scale data for block-wise quantization.
35+
qmap: lookup table that maps between quantized value (code) and float value.
36+
signed: whether the tensor is signed or unsigned.
37+
38+
NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
39+
Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
40+
Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
41+
"""
3142
assert codes.dtype is torch.uint8
43+
assert scale.ndim == 1
3244
self.codes = codes
3345
self.scale = scale
3446
self.qmap = qmap
3547
self.signed = signed
36-
37-
@property
38-
def block_size(self):
39-
return self.codes.numel() // self.scale.numel()
48+
self.block_size = codes.numel() // scale.numel()
4049

4150
def __tensor_flatten__(self):
4251
return self.tensor_attrs, [self.signed]
@@ -97,3 +106,31 @@ def _(func, *args, **kwargs):
97106
def _(func, *args, **kwargs):
98107
args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
99108
return func(*args, **kwargs)
109+
110+
111+
# this is needed for DTensor.from_local()
112+
@OptimState8bit.implements(aten.view.default)
113+
def _(func, *args, **kwargs):
114+
x, shape = args
115+
return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)
116+
117+
118+
# this is needed for DTensor.full_tensor()
119+
@OptimState8bit.implements([
120+
c10d_functional.all_gather_into_tensor.default,
121+
_c10d_functional.all_gather_into_tensor.default,
122+
c10d_functional.wait_tensor.default,
123+
_c10d_functional.wait_tensor.default,
124+
])
125+
def _(func, *args, **kwargs):
126+
x = args[0]
127+
if not isinstance(x, OptimState8bit):
128+
raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
129+
130+
# assume tensors from all ranks have the same signedness
131+
return OptimState8bit(
132+
func(x.codes, *args[1:], **kwargs),
133+
func(x.scale, *args[1:], **kwargs),
134+
x.qmap.clone(),
135+
x.signed,
136+
)

torchao/prototype/low_bit_optim/subclass_fp8.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55

66
aten = torch.ops.aten
7+
c10d_functional = torch.ops.c10d_functional
8+
_c10d_functional = torch.ops._c10d_functional
9+
710
DTYPE = torch.float8_e4m3fn
811

912

@@ -32,13 +35,21 @@ def __new__(cls, codes: Tensor, scale: Tensor):
3235
)
3336

3437
def __init__(self, codes: Tensor, scale: Tensor):
38+
"""Create quantized FP8 optimizer state.
39+
40+
Args
41+
codes: quantized FP8 E4M3FN data. Has the same shape as the original float tensor.
42+
scale: scale data for block-wise quantization.
43+
44+
NOTE: To get block-wise scale, the original float tensor is first reshape to (-1, block_size).
45+
Thus, the last dimension of the original float tensor is not necessarily divisible by block size.
46+
Given `codes` and `scale`, `block_size` is calculated as `codes.numel() // scale.numel()`.
47+
"""
3548
assert codes.dtype is DTYPE
49+
assert scale.ndim == 1
3650
self.codes = codes
3751
self.scale = scale
38-
39-
@property
40-
def block_size(self):
41-
return self.codes.numel() // self.scale.numel()
52+
self.block_size = codes.numel() // scale.numel()
4253

4354
def __tensor_flatten__(self):
4455
return self.tensor_attrs, []
@@ -99,3 +110,29 @@ def _(func, *args, **kwargs):
99110
def _(func, *args, **kwargs):
100111
args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
101112
return func(*args, **kwargs)
113+
114+
115+
# this is needed for DTensor.from_local()
116+
@OptimStateFp8.implements(aten.view.default)
117+
def _(func, *args, **kwargs):
118+
x, shape = args
119+
return OptimStateFp8(x.codes.view(shape), x.scale)
120+
121+
122+
# this is needed for DTensor.full_tensor()
123+
@OptimStateFp8.implements([
124+
c10d_functional.all_gather_into_tensor.default,
125+
_c10d_functional.all_gather_into_tensor.default,
126+
c10d_functional.wait_tensor.default,
127+
_c10d_functional.wait_tensor.default,
128+
])
129+
def _(func, *args, **kwargs):
130+
x = args[0]
131+
if not isinstance(x, OptimStateFp8):
132+
raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")
133+
134+
# assume tensors from all ranks have the same signedness
135+
return OptimStateFp8(
136+
func(x.codes, *args[1:], **kwargs),
137+
func(x.scale, *args[1:], **kwargs),
138+
)

0 commit comments

Comments
 (0)