Commit 65f660d

[low-bit optim] Fix Adam4bit support on PyTorch 2.3 and 2.4. Update AdamFp8 torch requirement (#755)
* update doc on torch version
* update doc
* update
* fix 4-bit problem
* update doc
* update
1 parent e5246fc commit 65f660d

3 files changed (+29, -16 lines)

test/prototype/test_low_bit_optim.py

Lines changed: 9 additions & 7 deletions
@@ -75,15 +75,16 @@ def test_quantize_4bit_with_qmap_compile(self, device):
 
 
 class TestOptim(TestCase):
-    @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit", "Adam4bit", "AdamW4bit", "AdamFp8", "AdamWFp8"])
     @parametrize("dtype", [torch.float32, torch.bfloat16])
     @parametrize("device", _DEVICES)
     def test_optim_smoke(self, optim_name, dtype, device):
-        if optim_name.endswith("Fp8") and device == "cuda" and torch.cuda.get_device_capability() < (8, 9):
-            pytest.skip("FP8 requires compute capability >= 8.9")
-        if optim_name.endswith("4bit") and not TORCH_VERSION_AT_LEAST_2_5:
-            pytest.skip("4-bit Adam requires PyTorch > 2.4")
+        if optim_name.endswith("Fp8") and device == "cuda":
+            if not TORCH_VERSION_AT_LEAST_2_4:
+                pytest.skip("FP8 CUDA requires PyTorch >= 2.4")
+            if torch.cuda.get_device_capability() < (8, 9):
+                pytest.skip("FP8 requires compute capability >= 8.9")
 
         # reset cache to avoid hitting cache_size_limit, since the function will re-compile for each test
         torch._dynamo.reset_code_caches()
@@ -100,7 +101,7 @@ def test_optim_smoke(self, optim_name, dtype, device):
 
     @pytest.mark.skipif(bnb is None, reason="bitsandbytes is not availablle")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="bitsandbytes 8-bit Adam only works for CUDA")
-    @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam8bit", "AdamW8bit"])
     def test_optim_8bit_correctness(self, optim_name):
         device = "cuda"
@@ -126,9 +127,10 @@ def test_optim_8bit_correctness(self, optim_name):
         for p1, p2 in zip(model1.parameters(), model2.parameters()):
             torch.testing.assert_close(p2, p1, rtol=1e-5, atol=1e-5)
 
+    # this will not run in CI because we can't install lpmm
     @pytest.mark.skipif(lpmm is None, reason="lpmm is not availablle")
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="lpmm 4-bit Adam only works for CUDA")
-    @pytest.mark.xfail(not TORCH_VERSION_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
+    @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
     @parametrize("optim_name", ["Adam4bit", "AdamW4bit"])
     def test_optim_4bit_correctness(self, optim_name):
         device = "cuda"
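
The substantive change in this file is swapping `pytest.mark.xfail` for `pytest.mark.skipif` on old PyTorch: with `xfail` the test body still runs and is merely expected to fail, while `skipif` reports the test as skipped without running it. A minimal sketch of the two markers, assuming a version flag computed from `torch.__version__` (the flag below is illustrative; torchao keeps its own `TORCH_VERSION_AT_LEAST_*` constants):

```python
# Illustrative version gate, not the repo's helper.
import pytest
import torch
from packaging.version import parse

TORCH_AT_LEAST_2_3 = parse(torch.__version__).release >= (2, 3)

# Old style: the body still executes on old PyTorch and is expected to fail.
@pytest.mark.xfail(not TORCH_AT_LEAST_2_3, reason="torch.compile() fails for PyTorch < 2.3")
def test_compile_with_xfail():
    torch.compile(lambda x: x + 1)(torch.ones(2))

# New style: on old PyTorch the test is reported as skipped and never runs.
@pytest.mark.skipif(not TORCH_AT_LEAST_2_3, reason="requires PyTorch >= 2.3")
def test_compile_with_skipif():
    torch.compile(lambda x: x + 1)(torch.ones(2))
```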

torchao/prototype/low_bit_optim/README.md

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ To use 4-bit Adam, replace the above with `Adam4bit`. Similarly for `AdamFp8`. Y
 **Other optimizers**: AdamW is also available as `AdamW8bit`, `AdamW4bit`, and `AdamWFp8`. Other optimizers can be added based on demand.
 
 NOTE:
-- The low-bit optimizers require PyTorch >= 2.3. FP8 optimizers require CUDA compute capability >= 8.9.
+- The low-bit optimizers require PyTorch >= 2.3
+- For FP8 optimizers on CUDA, PyTorch >= 2.4 and CUDA compute capability >= 8.9 are required.
 - For 4-bit optimizers, we don't implement rank-1 normalization for quantizing 2nd moment as originally done in the paper.
 - The first training step is expected to be slow since the optimizer needs to be compiled.
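
For orientation, the README's recommended usage looks roughly like the sketch below; a minimal example assuming the `torchao.prototype.low_bit_optim` import path and a CUDA device, not something changed by this commit:

```python
import torch
from torchao.prototype.low_bit_optim import AdamW8bit  # Adam4bit / AdamFp8 are drop-in swaps

model = torch.nn.Linear(1024, 1024, device="cuda")
optim = AdamW8bit(model.parameters(), lr=1e-3)

# The first step is slow because the optimizer step gets compiled.
x = torch.randn(8, 1024, device="cuda")
model(x).sum().backward()
optim.step()
optim.zero_grad()
```

Per the updated NOTE, swapping in `AdamFp8` additionally requires PyTorch >= 2.4 and CUDA compute capability >= 8.9.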

torchao/prototype/low_bit_optim/adam.py

Lines changed: 18 additions & 8 deletions
@@ -198,7 +198,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
 
     @staticmethod
     def _unwrap_dtensor(p: Tensor):
@@ -216,6 +216,11 @@ def step(self, closure=None):
         # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
         # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
 
+        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
+        # PyTorch 2.3 and 2.4
+        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
+        # correctly for the tensor subclass.
+
         # unwrap DTensor since DTensor does not work well with dynamic compile
         # flatten p, grad, and optim state to avoid recompilation
         for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
@@ -227,9 +232,9 @@ def step(self, closure=None):
                     self._unwrap_dtensor(p).view(-1),
                     self._unwrap_dtensor(grad).view(-1),
                     step,
-                    self._unwrap_dtensor(exp_avg).view(-1),
-                    self._unwrap_dtensor(exp_avg_sq).view(-1),
-                    self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
+                    self._unwrap_dtensor(exp_avg),
+                    self._unwrap_dtensor(exp_avg_sq),
+                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
                     lr,
                     beta1,
                     beta2,
@@ -296,7 +301,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
 
     @staticmethod
     def _unwrap_dtensor(p: Tensor):
@@ -314,6 +319,11 @@ def step(self, closure=None):
         # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
         # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
 
+        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
+        # PyTorch 2.3 and 2.4
+        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
+        # correctly for the tensor subclass.
+
         # unwrap DTensor since DTensor does not work well with dynamic compile
         # flatten p, grad, and optim state to avoid recompilation
         for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
@@ -325,9 +335,9 @@ def step(self, closure=None):
                     self._unwrap_dtensor(p).view(-1),
                     self._unwrap_dtensor(grad).view(-1),
                     step,
-                    self._unwrap_dtensor(exp_avg).view(-1),
-                    self._unwrap_dtensor(exp_avg_sq).view(-1),
-                    self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
+                    self._unwrap_dtensor(exp_avg),
+                    self._unwrap_dtensor(exp_avg_sq),
+                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
                     lr,
                     beta1,
                     beta2,
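
The core of the fix: the 4-bit state is now allocated with the flattened shape (`p.view(-1).shape`) up front, and the `.view(-1)` calls on `exp_avg`, `exp_avg_sq`, and `max_exp_avg_sq` are dropped at the call site, so the compiled per-parameter step never reshapes the tensor subclass, which is what broke `torch.compile()` on PyTorch 2.3 and 2.4. A minimal sketch of the same pattern using plain tensors (hypothetical names, not the torchao implementation; needs PyTorch >= 2.0 for `torch.compile`):

```python
import torch

def single_param_adam_flat(p_flat, grad_flat, step, exp_avg, exp_avg_sq,
                           lr, beta1, beta2, eps):
    # Every input is already 1-D, so no .view(-1) happens inside the
    # compiled region.
    exp_avg.lerp_(grad_flat, 1 - beta1)
    exp_avg_sq.lerp_(grad_flat.square(), 1 - beta2)
    bias_corr1 = 1 - beta1 ** step
    bias_corr2 = 1 - beta2 ** step
    denom = (exp_avg_sq / bias_corr2).sqrt().add_(eps)
    p_flat.addcdiv_(exp_avg / bias_corr1, denom, value=-lr)

compiled_step = torch.compile(single_param_adam_flat)

p = torch.randn(4, 8)
grad = torch.randn_like(p)
# Optimizer state is created flat from the start, mirroring
# OptimState4bit.zeros(p.view(-1).shape, ...) in the diff above.
exp_avg = torch.zeros(p.view(-1).shape)
exp_avg_sq = torch.zeros(p.view(-1).shape)
compiled_step(p.view(-1), grad.view(-1), 1, exp_avg, exp_avg_sq,
              lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8)
```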
