Skip to content

Commit 6b71049

Browse files
committed
unify NPU and XPU test cases into a single class
1 parent 330d397 commit 6b71049

File tree

1 file changed

+26
-51
lines changed

1 file changed

+26
-51
lines changed

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 26 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -4,14 +4,13 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import pytest
87
import tempfile
9-
import unittest
108

9+
import pytest
1110
import torch
11+
from torch.testing._internal.common_device_type import instantiate_device_type_tests
1212
from torch.testing._internal.common_utils import (
1313
TestCase,
14-
instantiate_parametrized_tests,
1514
parametrize,
1615
run_tests,
1716
)
@@ -34,49 +33,22 @@ def get_config(group_size):
3433
)
3534

3635

36+
_ALL_DEVICES = ("xpu", "npu")
3737
_MIN_VER = {
3838
"xpu": "2.8.0",
3939
"npu": "2.7.1",
4040
}
41-
THRESHOLD = {"xpu": 20, "npu": 10}
42-
43-
ALL_DEVICES = ("xpu", "npu")
44-
45-
46-
def _get_available_devices() -> tuple[list[str], list[str]]:
47-
available_devices = []
48-
messages = []
49-
for name in ALL_DEVICES:
50-
mod = getattr(torch, name, None)
51-
if mod is None:
52-
messages.append(f"{name}: not found in torch")
53-
continue
54-
avail = mod.is_available()
55-
status = []
56-
status.append(f"available={avail}")
57-
status.append(f"min_version_req={_MIN_VER[name]}")
58-
status.append(f"torch_version={torch.__version__}")
59-
if avail and torch_version_at_least(_MIN_VER[name]):
60-
available_devices.append(name)
61-
status.append("OK")
62-
else:
63-
status.append("FAIL")
64-
messages.append(f"{name}: " + ", ".join(status))
65-
66-
return available_devices, messages
67-
68-
69-
AVAILABLE_DEVICES, MESSAGES = _get_available_devices()
70-
print("\nDevice Status:")
71-
for msg in MESSAGES:
72-
print(" ", msg)
73-
74-
75-
@unittest.skipIf(
76-
not AVAILABLE_DEVICES, f"No available devices: {', '.join(ALL_DEVICES)}"
77-
)
41+
_THRESHOLD = {"xpu": 20, "npu": 10}
42+
43+
7844
class Int4PlainInt32Tensor(TestCase):
79-
@parametrize("device", AVAILABLE_DEVICES)
45+
def setUp(self):
46+
min_req = _MIN_VER.get(self.device_type)
47+
if not torch_version_at_least(min_req):
48+
self.skipTest(
49+
f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
50+
)
51+
8052
@parametrize(
8153
"sizes",
8254
[
@@ -87,13 +59,14 @@ class Int4PlainInt32Tensor(TestCase):
8759
)
8860
@parametrize("dtype", [torch.bfloat16, torch.half])
8961
@parametrize("group_size", [32, 64, 128])
90-
def test_linear(self, device, sizes, dtype, group_size):
62+
def test_linear(self, sizes, dtype, group_size):
63+
device = self.device_type
9164
M, N, K = sizes
9265
if device == "npu" and group_size == K:
9366
pytest.skip(
9467
f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
9568
)
96-
threshold = THRESHOLD.get(device)
69+
threshold = _THRESHOLD.get(device)
9770

9871
input = torch.randn(*M, K, dtype=dtype, device=device)
9972
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
@@ -107,9 +80,9 @@ def test_linear(self, device, sizes, dtype, group_size):
10780
quantized_and_compiled = compiled_linear(input)
10881
self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
10982

110-
@parametrize("device", AVAILABLE_DEVICES)
11183
@parametrize("dtype", [torch.bfloat16, torch.half])
112-
def test_module_path(self, device, dtype):
84+
def test_module_path(self, dtype):
85+
device = self.device_type
11386
K, N, group_size = 128, 256, 128
11487
if device == "npu":
11588
group_size = 64
@@ -130,19 +103,19 @@ def test_module_path(self, device, dtype):
130103
"<class 'torchao.quantization.Int4PlainInt32Tensor'>",
131104
)
132105

133-
@parametrize("device", AVAILABLE_DEVICES)
134106
@parametrize("dtype", [torch.float16, torch.bfloat16])
135-
def test_activation_prescaling(self, device, dtype):
107+
def test_activation_prescaling(self, dtype):
108+
device = self.device_type
136109
if device == "xpu" and dtype == torch.float16:
137110
pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")
138111

139-
threshold = THRESHOLD.get(device)
112+
threshold = _THRESHOLD.get(device)
140113
K, N, group_size = 128, 256, 128
141114
if device == "npu":
142115
group_size = 64
143116

144-
input = torch.randn(1, 128, dtype=dtype, device=device)
145-
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
117+
input = torch.randn(1, K, dtype=dtype, device=device)
118+
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
146119
original = linear(input)
147120
quantize_(linear, get_config(group_size))
148121
qw = linear.weight
@@ -158,7 +131,9 @@ def test_activation_prescaling(self, device, dtype):
158131
self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)
159132

160133

161-
instantiate_parametrized_tests(Int4PlainInt32Tensor)
134+
instantiate_device_type_tests(
135+
Int4PlainInt32Tensor, globals(), only_for=_ALL_DEVICES, allow_xpu=True
136+
)
162137

163138

164139
if __name__ == "__main__":

0 commit comments

Comments
 (0)