Commit 2c7b199

unify NPU and XPU test cases into a single class
1 parent 25360da commit 2c7b199

File tree

1 file changed

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 73 additions & 79 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 
+import pytest
 import tempfile
 import unittest
 
@@ -33,103 +34,88 @@ def get_config(group_size):
     )
 
 
-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
-class Int4PlainInt32TensorXPU(TestCase):
-    @parametrize(
-        "sizes",
-        [
-            ((128,), 256, 128),
-            ((32, 128), 512, 128),
-            ((2, 32, 128), 256, 12),
-        ],
-    )
-    @parametrize("dtype", [torch.bfloat16, torch.half])
-    @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = "xpu"
-        M, N, K = sizes
-        input = torch.randn(*M, K, dtype=dtype, device=device)
-        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
-        original = linear(input)
-        quantize_(linear, get_config(group_size))
-        quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+_MIN_VER = {
+    "xpu": "2.8.0",
+    "npu": "2.7.1",
+}
+THRESHOLD = {"xpu": 20, "npu": 10}
 
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+ALL_DEVICES = ("xpu", "npu")
 
-    @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
-        quantize_(linear, get_config(group_size=128))
-        self.assertEqual(
-            str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
-        )
 
-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(linear.state_dict(), f)
-            f.seek(0)
-            state_dict = torch.load(f)
-            self.assertEqual(
-                str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
-            )
+def _get_available_devices() -> tuple[list[str], list[str]]:
+    available_devices = []
+    messages = []
+    for name in ALL_DEVICES:
+        mod = getattr(torch, name, None)
+        if mod is None:
+            messages.append(f"{name}: not found in torch")
+            continue
+        avail = mod.is_available()
+        status = []
+        status.append(f"available={avail}")
+        status.append(f"min_version_req={_MIN_VER[name]}")
+        status.append(f"torch_version={torch.__version__}")
+        if avail and torch_version_at_least(_MIN_VER[name]):
+            available_devices.append(name)
+            status.append("OK")
+        else:
+            status.append("FAIL")
+        messages.append(f"{name}: " + ", ".join(status))
 
-    def test_activation_prescaling(self):
-        dtype = torch.bfloat16
-        device = "xpu"
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        original = linear(input)
-        quantize_(linear, get_config(128))
-        qw = linear.weight
-        assert isinstance(qw, SupportsActivationPreScaling), (
-            "Expected int4 tensor supports activation prescaling"
-        )
-        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
-        _ACT_PRE_SCALE = 2
-        qw.act_pre_scale = _ACT_PRE_SCALE
-        quantized = linear(input)
+    return available_devices, messages
 
-        # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
+
+AVAILABLE_DEVICES, MESSAGES = _get_available_devices()
+print("\nDevice Status:")
+for msg in MESSAGES:
+    print("  ", msg)
 
 
-@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
 @unittest.skipIf(
-    torch.accelerator.current_accelerator().type != "npu"
-    or not torch.accelerator.is_available(),
-    "NPU not available",
+    not AVAILABLE_DEVICES, f"No available devices: {', '.join(ALL_DEVICES)}"
 )
-class Int4PlainInt32TensorNPU(TestCase):
-    @parametrize("device", ["npu"])
+class Int4PlainInt32Tensor(TestCase):
+    @parametrize("device", AVAILABLE_DEVICES)
     @parametrize(
         "sizes",
         [
             ((128,), 256, 128),
             ((32, 128), 512, 128),
-            ((2, 32, 128), 256, 128),
+            ((2, 32, 128), 256, 12),
         ],
     )
-    @parametrize("dtype", [torch.float16, torch.bfloat16])
-    @parametrize("group_size", [32, 64])
+    @parametrize("dtype", [torch.bfloat16, torch.half])
+    @parametrize("group_size", [32, 64, 128])
     def test_linear(self, device, sizes, dtype, group_size):
         M, N, K = sizes
+        if device == "npu" and group_size == K:
+            pytest.skip(
+                f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
+            )
+        threshold = THRESHOLD.get(device)
+
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
-        orig_output = linear(input)
+        original = linear(input)
         quantize_(linear, get_config(group_size))
-        quantized_output = linear(input)
-        self.assertTrue(compute_error(orig_output, quantized_output) > 10)
+        quantized = linear(input)
+        self.assertTrue(compute_error(original, quantized) > threshold)
 
-    @parametrize("device", ["npu"])
-    @parametrize("dtype", [torch.float16, torch.bfloat16])
+        if device == "xpu":
+            compiled_linear = torch.compile(linear)
+            quantized_and_compiled = compiled_linear(input)
+            self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
+
+    @parametrize("device", AVAILABLE_DEVICES)
+    @parametrize("dtype", [torch.bfloat16, torch.half])
     def test_module_path(self, device, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        quantize_(linear, get_config(group_size=64))
+        K, N, group_size = 128, 256, 128
+        if device == "npu":
+            group_size = 64
+
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -144,13 +130,21 @@ def test_module_path(self, device, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
 
-    @parametrize("device", ["npu"])
+    @parametrize("device", AVAILABLE_DEVICES)
     @parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_activation_prescaling(self, device, dtype):
+        if device == "xpu" and dtype == torch.float16:
+            pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")
+
+        threshold = THRESHOLD.get(device)
+        K, N, group_size = 128, 256, 128
+        if device == "npu":
+            group_size = 64
+
         input = torch.randn(1, 128, dtype=dtype, device=device)
         linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(64))
+        quantize_(linear, get_config(group_size))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
            "Expected int4 tensor supports activation prescaling"
@@ -161,11 +155,11 @@ def test_activation_prescaling(self, device, dtype):
         quantized = linear(input)
 
         # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)
+
 
+instantiate_parametrized_tests(Int4PlainInt32Tensor)
 
-instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
-instantiate_parametrized_tests(Int4PlainInt32TensorNPU)
 
 if __name__ == "__main__":
     run_tests()
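
A note on the gating pattern this commit introduces: probing with getattr(torch, name, None) before calling is_available() is what lets one class serve both backends, since torch.npu only exists when the Ascend extension is installed. A minimal standalone sketch of that idiom (illustrative only; the probe helper and the printed message are assumptions, not part of the commit):

import torch

# Same minimum-version table as the test file
_MIN_VER = {"xpu": "2.8.0", "npu": "2.7.1"}

def probe(name: str) -> bool:
    # Backend modules such as torch.npu may be absent from a given build,
    # so fetch them defensively instead of referencing the attribute directly.
    mod = getattr(torch, name, None)
    return mod is not None and mod.is_available()

available = [name for name in _MIN_VER if probe(name)]
print(f"torch {torch.__version__}: usable accelerators: {available or 'none'}")

Feeding the resulting list straight into @parametrize("device", ...) is what makes the skip behavior automatic: on a machine with neither backend the whole class is skipped by the unittest.skipIf guard, and on a machine with one backend the parametrization simply never generates cases for the other.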
