Skip to content

Commit 0c24ee3

Browse files
committed
Enable model.to(device) for int8 weight only quantized model
Summary: Fix some implementation issues for `int8_wo_quantized_model.to(device)`. Test Plan: python test/quantization/test_quant_api.py -k test_quantized_model_to_device Reviewers: Subscribers: Tasks: Tags:
1 parent 12ac498 commit 0c24ee3

File tree

2 files changed

+22
-3
lines changed

2 files changed

+22
-3
lines changed

test/quantization/test_quant_api.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -620,5 +620,21 @@ def test_quantized_tensor_subclass_save_load(self):
620620
self.assertEqual(res, ref)
621621

622622

623+
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
624+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
625+
def test_quantized_model_to_device(self):
626+
m = ToyLinearModel().eval().to(torch.bfloat16)
627+
m_copy = copy.deepcopy(m)
628+
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cpu")
629+
630+
quantize_(m, int8_weight_only())
631+
ref = m(*example_inputs)
632+
633+
example_inputs_cuda = (example_inputs[0].to("cuda"),)
634+
m.to(device="cuda")
635+
cuda_res = m(*example_inputs_cuda)
636+
self.assertEqual(cuda_res.cpu(), ref)
637+
638+
623639
if __name__ == "__main__":
624640
unittest.main()

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,11 @@ def _get_to_kwargs(self, *args, **kwargs):
259259

260260
def to(self, *args, **kwargs):
261261
kwargs = self._get_to_kwargs(*args, **kwargs)
262+
device = kwargs.pop("device")
263+
# not supported yet
264+
kwargs.pop("memory_format")
262265
return self.__class__(
263-
self.layout_tensor.to(kwargs["device"]),
266+
self.layout_tensor.to(device),
264267
self.block_size,
265268
self.shape,
266269
self.quant_min,
@@ -470,8 +473,8 @@ def to(self, *args, **kwargs):
470473
if device != "cuda" or (isinstance(device, torch.device) and device.type != "cuda"):
471474
raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device")
472475
return self.__class__(
473-
self.packed_weight.to(kwargs["device"]),
474-
self.scale_and_zero.to(kwargs["device"]),
476+
self.packed_weight.to(device),
477+
self.scale_and_zero.to(device),
475478
self.transposed
476479
)
477480

0 commit comments

Comments (0)