Skip to content

Commit e8662e0

Browse files
authored
Fixing cuda device check (#536)
Summary: The previous cuda device check is not general enough; this adds a better check that works for more cases, like "cuda:0". Test Plan: python test/quantization/test_quant_api.py -k test_int4wo_quantized_model_to_device Reviewers: Subscribers: Tasks: Tags:
1 parent e5df48e commit e8662e0

File tree

3 files changed

+19
-13
lines changed

3 files changed

+19
-13
lines changed

test/quantization/test_quant_api.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -642,17 +642,19 @@ def test_int8wo_quantized_model_to_device(self):
642642
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "Test currently doesn't work for 2.5+")
643643
def test_int4wo_quantized_model_to_device(self):
644644
# TODO: change initial model to "cpu"
645-
m = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
646-
m_copy = copy.deepcopy(m)
647-
example_inputs = m.example_inputs(dtype=torch.bfloat16, device="cuda")
648-
649-
quantize_(m, int4_weight_only())
650-
ref = m(*example_inputs)
651-
652-
example_inputs_cuda = (example_inputs[0].to("cuda"),)
653-
m.to(device="cuda")
654-
cuda_res = m(*example_inputs_cuda)
655-
self.assertEqual(cuda_res.cpu(), ref)
645+
devices = ["cuda", "cuda:0"]
646+
for device in devices:
647+
m = ToyLinearModel().eval().to(torch.bfloat16).to(device)
648+
m_copy = copy.deepcopy(m)
649+
example_inputs = m.example_inputs(dtype=torch.bfloat16, device=device)
650+
651+
quantize_(m, int4_weight_only())
652+
ref = m(*example_inputs)
653+
654+
example_inputs_cuda = (example_inputs[0].to(device),)
655+
m.to(device=device)
656+
cuda_res = m(*example_inputs_cuda)
657+
self.assertEqual(cuda_res.cpu(), ref)
656658

657659
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
658660
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
_register_layout_cls,
2222
_get_layout_tensor_constructor,
2323
LayoutType,
24+
is_device,
2425
)
2526
from typing import ClassVar
2627
from dataclasses import dataclass
@@ -544,7 +545,7 @@ def from_plain(
544545
def to(self, *args, **kwargs):
545546
kwargs = self._get_to_kwargs(*args, **kwargs)
546547
device = kwargs["device"]
547-
if device != "cuda" and (isinstance(device, torch.device) and device.type != "cuda"):
548+
if not is_device("cuda", device):
548549
raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}")
549550
return self.__class__(
550551
self.packed_weight.to(device),

torchao/dtypes/utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from typing import Dict, Callable
2+
from typing import Dict, Callable, Union
33
from collections import defaultdict
44
import functools
55
from dataclasses import dataclass
@@ -89,3 +89,6 @@ def _get_layout_tensor_constructor(cls: Callable, layout_type_class: type(Layout
8989
raise ValueError(f"layout_name: {layout_type_class} is not supported yet for {cls}")
9090

9191
return _LAYOUT_CONSTRUCTOR_TABLE[cls][layout_type_class]
92+
93+
def is_device(target_device_str: str, device: Union[str, torch.device]):
94+
return torch.device(device).type == target_device_str

0 commit comments

Comments (0)