
Commit 7a2a7b3

Add NPU (Ascend) backend support for INT4 weight-only quantization workflow (#3172)
* Add NPU (Ascend) backend support for INT4 weight-only quantization workflow
* Use torch.ops.npu prefix and drop redundant torch_npu import
* Modify test file and update comments
* Merge NPU (Ascend) backend logic into the Int4PlainInt32Tensor subclass
* Ruff format cleanup, replace error types, add torch version check
* Add torch_npu version assertion and show downstream testing result
* Add downstream testing result
* Unify NPU and XPU test cases into a single class
* Move CI display to the quantization README and update test file
1 parent dbc89d3 commit 7a2a7b3
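
For context, here is a minimal usage sketch of the flow this commit enables. It is an illustration under assumptions, not taken from the diff: it assumes torch_npu ≥ 2.7.1 with an Ascend device exposed as `"npu"`, and that the test's `get_config` helper corresponds to an `Int4WeightOnlyConfig` using the plain-int32 packing format.

```python
# Hedged sketch of the NPU int4 weight-only workflow exercised by the new tests.
# Assumptions: torch_npu >= 2.7.1 is installed, an Ascend device is visible as
# "npu", and int4_packing_format="plain_int32" is the format that backs
# Int4PlainInt32Tensor (not confirmed by this diff).
import torch
import torch_npu  # noqa: F401  registers the Ascend "npu" device with PyTorch

from torchao.quantization import Int4WeightOnlyConfig, quantize_

device = "npu"
dtype = torch.bfloat16

linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
# group_size=64: the new test skips group_size == K (128) on NPU.
config = Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32")

quantize_(linear, config)  # linear.weight becomes a torchao Int4PlainInt32Tensor
out = linear(torch.randn(1, 128, dtype=dtype, device=device))
```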

File tree

3 files changed (+292, -78 lines)


test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 52 additions & 21 deletions
```diff
@@ -5,12 +5,12 @@
 # LICENSE file in the root directory of this source tree.

 import tempfile
-import unittest

+import pytest
 import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import (
     TestCase,
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
@@ -33,9 +33,19 @@ def get_config(group_size):
 )


-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
 class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
+    def setUp(self):
+        min_req = type(self)._MIN_VER.get(self.device_type)
+        if not torch_version_at_least(min_req):
+            self.skipTest(
+                f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
+            )
+
     @parametrize(
         "sizes",
         [
@@ -46,24 +56,35 @@ class Int4PlainInt32Tensor(TestCase):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = "xpu"
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
+        if "npu" in device and group_size == K:
+            pytest.skip(
+                f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
+            )
+        threshold = thresholds.get(device.split(":")[0])
+
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
         quantize_(linear, get_config(group_size))
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        self.assertTrue(compute_error(original, quantized) > threshold)

-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
+        if "xpu" in device:
+            compiled_linear = torch.compile(linear)
+            quantized_and_compiled = compiled_linear(input)
+            self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)

     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
-        quantize_(linear, get_config(group_size=128))
+    def test_module_path(self, device, dtype):
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -78,13 +99,21 @@ def test_module_path(self, dtype):
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
         )

-    def test_activation_prescaling(self):
-        dtype = torch.bfloat16
-        device = "xpu"
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_activation_prescaling(self, device, dtype, thresholds):
+        if "xpu" in device and dtype == torch.float16:
+            pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")
+
+        threshold = thresholds.get(device.split(":")[0])
+        K, N, group_size = 128, 256, 128
+        if "npu" in device:
+            group_size = 64
+
+        input = torch.randn(1, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(128))
+        quantize_(linear, get_config(group_size))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
@@ -95,10 +124,12 @@ def test_activation_prescaling(self):
         quantized = linear(input)

         # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)


-instantiate_parametrized_tests(Int4PlainInt32Tensor)
+instantiate_device_type_tests(
+    Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
+)


 if __name__ == "__main__":
```
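
One note on the harness change visible above: `instantiate_parametrized_tests` is replaced by `instantiate_device_type_tests`, which generates one test class per backend and passes every test a concrete `device` string; that is why the test bodies now branch on `"npu" in device` / `"xpu" in device` instead of hard-coding `device = "xpu"`. A self-contained sketch of that mechanism (restricted to CPU so it runs anywhere; the real test restricts to `("xpu", "npu")` and sets `allow_xpu=True`):

```python
# Minimal illustration (assumed behavior, not part of this commit) of how
# instantiate_device_type_tests fans a generic TestCase out per backend and
# injects a `device` argument into each test.
from torch.testing._internal.common_device_type import instantiate_device_type_tests
from torch.testing._internal.common_utils import TestCase, run_tests


class _Demo(TestCase):
    def test_device_is_injected(self, device):
        # `device` arrives as a concrete string such as "cpu", "xpu:0", or "npu:0",
        # so a single test body can branch per backend the way test_linear does.
        self.assertIsInstance(device, str)


# Generates a _DemoCPU class in this module; the real test uses only_for=("xpu", "npu").
instantiate_device_type_tests(_Demo, globals(), only_for=("cpu",))

if __name__ == "__main__":
    run_tests()
```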

torchao/quantization/README.md

Lines changed: 6 additions & 2 deletions
````diff
@@ -71,8 +71,12 @@ use_hqq = False
 quantize_(model, Int4WeightOnlyConfig(group_size=group_size, int4_packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq"))
 ```

-Note: The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
-
+Note:
+- The quantization error incurred by applying int4 quantization to your model can be fairly significant, so using external techniques like GPTQ may be necessary to obtain a usable model.
+- Third-party backend CI status:
+  - Ascend NPU (requires torch_npu ≥ 2.7.1)
+    [![Ascend NPU](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml/badge.svg)](https://github.com/Ascend/Ascend-CI/actions/workflows/torchao.yml)
+
 #### A16W8 Int8 WeightOnly Quantization

 ```python
````
