Commit 38b1f49

unify NPU and XPU test cases into a single class
1 parent 6b71049 commit 38b1f49

File tree

1 file changed: +19 -21 lines changed

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 19 additions & 21 deletions
@@ -33,17 +33,14 @@ def get_config(group_size):
     )
 
 
-_ALL_DEVICES = ("xpu", "npu")
-_MIN_VER = {
-    "xpu": "2.8.0",
-    "npu": "2.7.1",
-}
-_THRESHOLD = {"xpu": 20, "npu": 10}
-
-
 class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
     def setUp(self):
-        min_req = _MIN_VER.get(self.device_type)
+        min_req = type(self)._MIN_VER.get(self.device_type)
         if not torch_version_at_least(min_req):
             self.skipTest(
                 f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
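Note: moving `_MIN_VER` from module scope onto the class means every device-specific class generated from `Int4PlainInt32Tensor` carries the version table with it. A minimal standalone sketch of the idiom (hypothetical example, not part of the commit):

```python
import unittest


class DeviceAwareCase(unittest.TestCase):
    # Per-device data travels with the class instead of the module, so
    # generated device-specific subclasses inherit (and may override) it.
    _MIN_VER = {"xpu": "2.8.0", "npu": "2.7.1"}
    device_type = "xpu"  # in practice set by the device-type test framework

    def setUp(self):
        # type(self)._MIN_VER resolves on the most-derived class, so a
        # subclass override wins over this base-class default.
        min_req = type(self)._MIN_VER.get(self.device_type)
        if min_req is None:
            self.skipTest(f"no minimum version registered for {self.device_type}")
```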
@@ -59,14 +56,14 @@ def setUp(self):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = self.device_type
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
-        if device == "npu" and group_size == K:
+        if "npu" in device and group_size == K:
             pytest.skip(
                 f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
             )
-        threshold = _THRESHOLD.get(device)
+        threshold = thresholds.get(device.split(":")[0])
 
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
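The `device` string supplied by the device-type test framework can carry a device index (e.g. "xpu:0"), which is why the lookup strips everything after the colon before keying into the threshold table. A quick illustration with the values from this diff:

```python
thresholds = {"xpu": 20, "npu": 10}

# Device strings may arrive with an index suffix, so key on the bare
# device type rather than the full string.
for device in ("xpu", "xpu:0", "npu:1"):
    print(device, "->", thresholds.get(device.split(":")[0]))
# xpu -> 20, xpu:0 -> 20, npu:1 -> 10
```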
@@ -75,16 +72,16 @@ def test_linear(self, sizes, dtype, group_size):
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > threshold)
 
-        if device == "xpu":
+        if "xpu" in device:
             compiled_linear = torch.compile(linear)
             quantized_and_compiled = compiled_linear(input)
             self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
 
     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
+    def test_module_path(self, device, dtype):
         device = self.device_type
         K, N, group_size = 128, 256, 128
-        if device == "npu":
+        if "npu" in device:
             group_size = 64
 
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
@@ -104,14 +101,15 @@ def test_module_path(self, dtype):
         )
 
     @parametrize("dtype", [torch.float16, torch.bfloat16])
-    def test_activation_prescaling(self, dtype):
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_activation_prescaling(self, device, dtype, thresholds):
         device = self.device_type
-        if device == "xpu" and dtype == torch.float16:
+        if "xpu" in device and dtype == torch.float16:
             pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")
 
-        threshold = _THRESHOLD.get(device)
+        threshold = thresholds.get(device.split(":")[0])
         K, N, group_size = 128, 256, 128
-        if device == "npu":
+        if "npu" in device:
             group_size = 64
 
         input = torch.randn(1, K, dtype=dtype, device=device)
@@ -132,7 +130,7 @@ def test_activation_prescaling(self, dtype):
 
 
 instantiate_device_type_tests(
-    Int4PlainInt32Tensor, globals(), only_for=_ALL_DEVICES, allow_xpu=True
+    Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
)
 
 
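For context, `instantiate_device_type_tests` is what lets the single class serve both backends: it generates one concrete test class per device type in `only_for` and injects a `device` string into each test method. A rough sketch, assuming PyTorch's internal test helpers (the NPU variant additionally needs Ascend's `torch_npu` package installed and registered, so this is illustrative rather than universally runnable):

```python
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import TestCase, run_tests


class ExampleCase(TestCase):
    def test_device_string(self, device):
        # `device` is injected per generated class, e.g. "xpu" or "xpu:0".
        self.assertEqual(device.split(":")[0], self.device_type)


# Generates device-specific classes (e.g. ExampleCaseXPU) into globals();
# allow_xpu=True opts XPU into the device-type test machinery.
instantiate_device_type_tests(
    ExampleCase, globals(), only_for=("xpu", "npu"), allow_xpu=True
)

if __name__ == "__main__":
    run_tests()
```

On a machine without a given accelerator, the corresponding generated class is typically not instantiated at all, so the unified class runs only where its backend is actually available.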