Commit f3aefca

Add NPU (Ascend) backend support for INT4 weight-only quantization workflow

Parent: f64daac
5 files changed: +372 additions, −4 deletions
New file (113 additions, 0 deletions) — test suite for the NPU plain-int32 INT4 tensor:

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import unittest
import tempfile
from packaging import version

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

from torchao.quantization import (
    Int4WeightOnlyConfig,
    quantize_,
)
from torchao.quantization.quantize_.common import SupportsActivationPreScaling
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)

try:
    import torch_npu
except ImportError:
    torch_npu = None


def get_config(group_size):
    return Int4WeightOnlyConfig(
        group_size=group_size,
        int4_packing_format="plain_int32",
    )


@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
@unittest.skipIf(torch_npu is None, "torch_npu is not available")
@unittest.skipIf(not torch_npu.npu.is_available(), "NPU not available")
@unittest.skipIf(
    version.parse(torch_npu.__version__) < version.parse("2.7.1rc1"),
    "Need torch_npu 2.7.1rc1+",
)
class Int4PlainInt32TensorNPU(TestCase):

    @parametrize("device", ["npu"])
    @parametrize(
        "sizes",
        [
            ((128,), 256, 128),
            ((32, 128), 512, 128),
            ((2, 32, 128), 256, 128),
        ],
    )
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    @parametrize("group_size", [32, 64])
    def test_linear(self, device, sizes, dtype, group_size):
        M, N, K = sizes
        input = torch.randn(*M, K, dtype=dtype, device=device)
        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
        orig_output = linear(input)
        quantize_(linear, get_config(group_size))
        quantized_output = linear(input)
        self.assertTrue(compute_error(orig_output, quantized_output) > 10)

    @parametrize("device", ["npu"])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_module_path(self, device, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
        quantize_(linear, get_config(group_size=64))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Int4PlainInt32TensorNPU'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Int4PlainInt32TensorNPU'>",
            )

    @parametrize("device", ["npu"])
    @parametrize("dtype", [torch.float16, torch.bfloat16])
    def test_activation_prescaling(self, device, dtype):
        input = torch.randn(1, 128, dtype=dtype, device=device)
        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
        original = linear(input)
        quantize_(linear, get_config(64))
        qw = linear.weight
        assert isinstance(
            qw, SupportsActivationPreScaling
        ), "Expected int4 tensor to support activation pre-scaling"
        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
        _ACT_PRE_SCALE = 2
        qw.act_pre_scale = _ACT_PRE_SCALE
        quantized = linear(input)

        # make sure activation pre-scaling is applied to the activation
        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)


instantiate_parametrized_tests(Int4PlainInt32TensorNPU)

if __name__ == "__main__":
    run_tests()
```
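The tests accept the quantized layer when `compute_error` exceeds 10. In torchao this helper measures the signal-to-quantized-noise ratio (SQNR) in decibels; a minimal sketch of the check, re-deriving the metric under that assumption (the `sqnr_db` name is illustrative, not part of the commit):

```python
# Sketch (assumption): compute_error returns SQNR in dB, i.e.
# 20 * log10(||reference|| / ||reference - approx||); the tests require > 10 dB.
import torch

def sqnr_db(reference: torch.Tensor, approx: torch.Tensor) -> torch.Tensor:
    signal = torch.linalg.norm(reference.float())
    noise = torch.linalg.norm((reference - approx).float())
    return 20 * torch.log10(signal / noise)

ref = torch.randn(32, 256)
approx = ref + 0.05 * torch.randn_like(ref)  # small perturbation, stand-in for quantization noise
assert sqnr_db(ref, approx) > 10  # same acceptance threshold as the tests above
```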

torchao/quantization/__init__.py — 1 addition, 0 deletions

```diff
@@ -94,6 +94,7 @@
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
     Int4PlainInt32Tensor,
+    Int4PlainInt32TensorNPU,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
```

torchao/quantization/quant_api.py — 11 additions, 4 deletions

```diff
@@ -78,6 +78,7 @@
     Int4OpaqueTensor,
     Int4PackingFormat,
     Int4PlainInt32Tensor,
+    Int4PlainInt32TensorNPU,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
@@ -1210,10 +1211,16 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
-        new_weight = Int4PlainInt32Tensor.from_hp(
-            weight,
-            block_size,
-        )
+        if weight.device.type == "npu":
+            new_weight = Int4PlainInt32TensorNPU.from_hp(
+                weight,
+                block_size,
+            )
+        else:
+            new_weight = Int4PlainInt32Tensor.from_hp(
+                weight,
+                block_size,
+            )
         return new_weight
     elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
```
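The new branch keys off `weight.device.type`, so the same `plain_int32` packing format transparently resolves to the Ascend-specific subclass on NPU and to the pre-existing `Int4PlainInt32Tensor` everywhere else. A hedged usage sketch (assumes an Ascend environment with `torch_npu` installed; the device string and dtype are illustrative):

```python
# Sketch: one config, device-dependent tensor subclass.
import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

config = Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32")

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="npu")
quantize_(linear, config)
print(type(linear.weight))  # Int4PlainInt32TensorNPU on an NPU device

# On any other device type the original path is taken and the weight
# becomes an Int4PlainInt32Tensor instead.
```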

torchao/quantization/quantize_/workflows/__init__.py — 4 additions, 0 deletions

```diff
@@ -13,6 +13,9 @@
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32Tensor,
 )
+from .int4.int4_plain_int32_tensor_npu import (
+    Int4PlainInt32TensorNPU,
+)
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -36,6 +39,7 @@
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "Int4PlainInt32Tensor",
+    "Int4PlainInt32TensorNPU",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
```
