Commit fa3220f

unify NPU and XPU test cases into a single class
1 parent 25360da commit fa3220f

1 file changed: +52 -85 lines changed

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 52 additions & 85 deletions
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.
 
 import tempfile
-import unittest
 
+import pytest
 import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import (
     TestCase,
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
@@ -33,9 +33,19 @@ def get_config(group_size):
     )
 
 
-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
-class Int4PlainInt32TensorXPU(TestCase):
+class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
+    def setUp(self):
+        min_req = type(self)._MIN_VER.get(self.device_type)
+        if not torch_version_at_least(min_req):
+            self.skipTest(
+                f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
+            )
+
     @parametrize(
         "sizes",
         [
@@ -46,90 +56,36 @@ class Int4PlainInt32TensorXPU(TestCase):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = "xpu"
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
+        if "npu" in device and group_size == K:
+            pytest.skip(
+                f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
+            )
+        threshold = thresholds.get(device.split(":")[0])
+
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
         quantize_(linear, get_config(group_size))
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        self.assertTrue(compute_error(original, quantized) > threshold)
 
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+        if "xpu" in device:
+            compiled_linear = torch.compile(linear)
+            quantized_and_compiled = compiled_linear(input)
+            self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
 
     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
-        quantize_(linear, get_config(group_size=128))
-        self.assertEqual(
-            str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
-        )
-
-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(linear.state_dict(), f)
-            f.seek(0)
-            state_dict = torch.load(f)
-            self.assertEqual(
-                str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
-            )
-
-    def test_activation_prescaling(self):
-        dtype = torch.bfloat16
-        device = "xpu"
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        original = linear(input)
-        quantize_(linear, get_config(128))
-        qw = linear.weight
-        assert isinstance(qw, SupportsActivationPreScaling), (
-            "Expected int4 tensor supports activation prescaling"
-        )
-        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
-        _ACT_PRE_SCALE = 2
-        qw.act_pre_scale = _ACT_PRE_SCALE
-        quantized = linear(input)
-
-        # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
-
+    def test_module_path(self, device, dtype):
+        device = self.device_type
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
 
-@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
-@unittest.skipIf(
-    torch.accelerator.current_accelerator().type != "npu"
-    or not torch.accelerator.is_available(),
-    "NPU not available",
-)
-class Int4PlainInt32TensorNPU(TestCase):
-    @parametrize("device", ["npu"])
-    @parametrize(
-        "sizes",
-        [
-            ((128,), 256, 128),
-            ((32, 128), 512, 128),
-            ((2, 32, 128), 256, 128),
-        ],
-    )
-    @parametrize("dtype", [torch.float16, torch.bfloat16])
-    @parametrize("group_size", [32, 64])
-    def test_linear(self, device, sizes, dtype, group_size):
-        M, N, K = sizes
-        input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
-        orig_output = linear(input)
         quantize_(linear, get_config(group_size))
-        quantized_output = linear(input)
-        self.assertTrue(compute_error(orig_output, quantized_output) > 10)
-
-    @parametrize("device", ["npu"])
-    @parametrize("dtype", [torch.float16, torch.bfloat16])
-    def test_module_path(self, device, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        quantize_(linear, get_config(group_size=64))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -144,13 +100,22 @@ def test_module_path(self, device, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )
 
-    @parametrize("device", ["npu"])
     @parametrize("dtype", [torch.float16, torch.bfloat16])
-    def test_activation_prescaling(self, device, dtype):
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_activation_prescaling(self, device, dtype, thresholds):
+        device = self.device_type
+        if "xpu" in device and dtype == torch.float16:
+            pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")
+
+        threshold = thresholds.get(device.split(":")[0])
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        input = torch.randn(1, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(64))
+        quantize_(linear, get_config(group_size))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
@@ -161,11 +126,13 @@ def test_activation_prescaling(self, device, dtype):
         quantized = linear(input)
 
         # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)
+
 
+instantiate_device_type_tests(
+    Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
+)
 
-instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
-instantiate_parametrized_tests(Int4PlainInt32TensorNPU)
 
 if __name__ == "__main__":
     run_tests()
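
Note on the pattern used by this commit: instantiate_device_type_tests (from torch.testing._internal.common_device_type) replaces the two instantiate_parametrized_tests calls by cloning the generic Int4PlainInt32Tensor class once per selected backend and passing a concrete device into each test, which is what lets the former XPU and NPU classes collapse into one. Below is a minimal sketch of that mechanism, separate from torchao; the class and test names are illustrative only, and it assumes a PyTorch build with XPU support plus a registered NPU backend for the "npu" entry.

import torch
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleDeviceTest(TestCase):
    # Each generated variant receives a concrete device string (e.g. "xpu:0"),
    # while the bare backend name is available as self.device_type.
    def test_add(self, device):
        x = torch.ones(4, device=device)
        self.assertEqual((x + x).sum().item(), 8.0)


# Emits per-device classes (e.g. ExampleDeviceTestXPU) into this module's
# namespace; only_for restricts generation to the listed backends and
# allow_xpu=True opts XPU into the device-type test machinery.
instantiate_device_type_tests(
    ExampleDeviceTest, globals(), only_for=("xpu", "npu"), allow_xpu=True
)

if __name__ == "__main__":
    run_tests()

With that in place, per-backend differences (minimum torch version, error thresholds, supported group sizes) are handled inside the single shared class by branching on self.device_type or the device argument, as the diff above does in setUp, test_linear, and test_activation_prescaling.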
