@@ -198,7 +198,7 @@ def __init__(
 
     @staticmethod
    def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
 
     @staticmethod
     def _unwrap_dtensor(p: Tensor):
@@ -216,6 +216,11 @@ def step(self, closure=None):
         # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
         # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
 
+        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
+        # PyTorch 2.3 and 2.4
+        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
+        # correctly for the tensor subclass.
+
         # unwrap DTensor since DTensor does not work well with dynamic compile
         # flatten p, grad, and optim state to avoid recompilation
         for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
@@ -227,9 +232,9 @@ def step(self, closure=None):
                     self._unwrap_dtensor(p).view(-1),
                     self._unwrap_dtensor(grad).view(-1),
                     step,
-                    self._unwrap_dtensor(exp_avg).view(-1),
-                    self._unwrap_dtensor(exp_avg_sq).view(-1),
-                    self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
+                    self._unwrap_dtensor(exp_avg),
+                    self._unwrap_dtensor(exp_avg_sq),
+                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
                     lr,
                     beta1,
                     beta2,
@@ -296,7 +301,7 @@ def __init__(
 
     @staticmethod
     def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
-        return OptimState4bit.zeros(p.shape, signed, block_size, p.device)
+        return OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device)
 
     @staticmethod
     def _unwrap_dtensor(p: Tensor):
@@ -314,6 +319,11 @@ def step(self, closure=None):
         # NOTE: right now, torch.compile(param_groups_adam) will have excessive memory usage for 4-bit optim.
         # thus, as a workaround, we use torch.compile(single_param_adam) and call it for each param.
 
+        # NOTE: we have to create flattened optimizer states since torch.compile() will fail otherwise for
+        # PyTorch 2.3 and 2.4
+        # calling exp_avg.view(-1) will fail torch.compile(single_param_adam) even if we implement the op
+        # correctly for the tensor subclass.
+
         # unwrap DTensor since DTensor does not work well with dynamic compile
         # flatten p, grad, and optim state to avoid recompilation
         for group, lr, (beta1, beta2), weight_decay, eps in param_groups:
@@ -325,9 +335,9 @@ def step(self, closure=None):
                     self._unwrap_dtensor(p).view(-1),
                     self._unwrap_dtensor(grad).view(-1),
                     step,
-                    self._unwrap_dtensor(exp_avg).view(-1),
-                    self._unwrap_dtensor(exp_avg_sq).view(-1),
-                    self._unwrap_dtensor(max_exp_avg_sq).view(-1) if max_exp_avg_sq is not None else None,
+                    self._unwrap_dtensor(exp_avg),
+                    self._unwrap_dtensor(exp_avg_sq),
+                    self._unwrap_dtensor(max_exp_avg_sq) if max_exp_avg_sq is not None else None,
                     lr,
                     beta1,
                     beta2,
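
Below is a minimal, self-contained sketch of the pattern this diff adopts. The toy update and the plain-tensor `exp_avg` are illustrative stand-ins, not torchao's actual `OptimState4bit` API: the point is that the optimizer state is allocated already flattened over `p.view(-1).shape`, so the compiled per-parameter step receives it as-is and never calls `.view(-1)` on the tensor subclass.

```python
import torch
from torch import Tensor


def single_param_step(p: Tensor, grad: Tensor, exp_avg: Tensor, lr: float, beta1: float):
    # Toy update standing in for single_param_adam: exp_avg arrives already 1-D,
    # so no .view(-1) is needed inside the compiled region.
    exp_avg.lerp_(grad, 1 - beta1)
    p.add_(exp_avg, alpha=-lr)


compiled_step = torch.compile(single_param_step, fullgraph=True, dynamic=True)

p = torch.randn(4, 8)
grad = torch.randn_like(p)

# State is created over p.view(-1).shape, mirroring _subclass_zeros above; a real
# 4-bit state would come from OptimState4bit.zeros(p.view(-1).shape, signed, block_size, p.device).
exp_avg = torch.zeros(p.view(-1).shape)

# p and grad are still flattened at call time, but the state is passed unmodified.
compiled_step(p.view(-1), grad.view(-1), exp_avg, lr=1e-3, beta1=0.9)
```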