Commit ea2aa7a

add: merge NPU(Ascend) backend logic in Int4PlainInt32Tensor subclass
1 parent 498f052 commit ea2aa7a

7 files changed: +296 −409 lines changed

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 69 additions & 2 deletions
@@ -35,7 +35,7 @@ def get_config(group_size):
 
 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
-class Int4PlainInt32Tensor(TestCase):
+class Int4PlainInt32TensorXPU(TestCase):
     @parametrize(
         "sizes",
         [
@@ -98,8 +98,75 @@ def test_activation_prescaling(self):
         self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
 
 
-instantiate_parametrized_tests(Int4PlainInt32Tensor)
+@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
+@unittest.skipIf(
+    torch.accelerator.current_accelerator().type != "npu"
+    or not torch.accelerator.is_available(),
+    "NPU not available",
+)
+class Int4PlainInt32TensorNPU(TestCase):
+
+    @parametrize("device", ["npu"])
+    @parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 512, 128),
+            ((2, 32, 128), 256, 128),
+        ],
+    )
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    @parametrize("group_size", [32, 64])
+    def test_linear(self, device, sizes, dtype, group_size):
+        M, N, K = sizes
+        input = torch.randn(*M, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
+        orig_output = linear(input)
+        quantize_(linear, get_config(group_size))
+        quantized_output = linear(input)
+        self.assertTrue(compute_error(orig_output, quantized_output) > 10)
+
+    @parametrize("device", ["npu"])
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_module_path(self, device, dtype):
+        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        quantize_(linear, get_config(group_size=64))
+        self.assertEqual(
+            str(type(linear.weight)),
+            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+        )
+
+        with tempfile.NamedTemporaryFile() as f:
+            torch.save(linear.state_dict(), f)
+            f.seek(0)
+            state_dict = torch.load(f)
+            self.assertEqual(
+                str(type(state_dict["weight"])),
+                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
+            )
+
+    @parametrize("device", ["npu"])
+    @parametrize("dtype", [torch.float16, torch.bfloat16])
+    def test_activation_prescaling(self, device, dtype):
+        input = torch.randn(1, 128, dtype=dtype, device=device)
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        original = linear(input)
+        quantize_(linear, get_config(64))
+        qw = linear.weight
+        assert isinstance(
+            qw, SupportsActivationPreScaling
+        ), "Expected int4 tensor supports activation prescaling"
+        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
+        _ACT_PRE_SCALE = 2
+        qw.act_pre_scale = _ACT_PRE_SCALE
+        quantized = linear(input)
+
+        # making sure activation pre scaling is successfully applied to the activation
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)
+
 
+instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
+instantiate_parametrized_tests(Int4PlainInt32TensorNPU)
 
 if __name__ == "__main__":
     run_tests()
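For reference, both test classes now drive the same public entry point. Below is a minimal usage sketch of the flow these tests exercise; it assumes the test file's get_config helper builds an Int4WeightOnlyConfig with the plain_int32 packing format, and that torch_npu (or an XPU build) registers the accelerator — neither detail is shown in this diff.

    # Minimal sketch, not part of the commit. Assumes torchao with this change
    # and an Ascend NPU (via torch_npu) or Intel XPU visible to torch.accelerator.
    import torch
    from torchao.quantization import Int4WeightOnlyConfig, quantize_

    device = torch.accelerator.current_accelerator().type  # "npu" or "xpu"
    linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device=device)

    # One config and one tensor subclass for both backends after this commit.
    quantize_(linear, Int4WeightOnlyConfig(group_size=64, int4_packing_format="plain_int32"))
    print(type(linear.weight))  # <class 'torchao.quantization.Int4PlainInt32Tensor'>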

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor_npu.py

Lines changed: 0 additions & 107 deletions
This file was deleted.

torchao/quantization/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -94,7 +94,6 @@
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
     Int4PlainInt32Tensor,
-    Int4PlainInt32TensorNPU,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,

torchao/quantization/quant_api.py

Lines changed: 4 additions & 11 deletions
@@ -78,7 +78,6 @@
     Int4OpaqueTensor,
     Int4PackingFormat,
     Int4PlainInt32Tensor,
-    Int4PlainInt32TensorNPU,
     Int4PreshuffledTensor,
     Int4Tensor,
     Int4TilePackedTo4dTensor,
@@ -1195,16 +1194,10 @@ def _int4_weight_only_quantize_tensor(weight, config):
         )
         return new_weight
     elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
-        if weight.device.type == "npu":
-            new_weight = Int4PlainInt32TensorNPU.from_hp(
-                weight,
-                block_size,
-            )
-        else:
-            new_weight = Int4PlainInt32Tensor.from_hp(
-                weight,
-                block_size,
-            )
+        new_weight = Int4PlainInt32Tensor.from_hp(
+            weight,
+            block_size,
+        )
         return new_weight
     elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
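With the branch removed from quant_api.py, the device split presumably moves inside Int4PlainInt32Tensor.from_hp (that file's diff is not shown in this excerpt). A rough sketch of that kind of device-keyed dispatch follows; the class and helper names are hypothetical stand-ins, not the real implementation.

    import torch

    class _Int4PlainInt32Sketch:
        """Hypothetical stand-in illustrating device-keyed dispatch in from_hp."""

        @classmethod
        def from_hp(cls, w: torch.Tensor, block_size: list[int]):
            # Route on the device of the high-precision weight, mirroring the
            # branch that used to sit in _int4_weight_only_quantize_tensor.
            if w.device.type == "npu":
                return cls._pack_ascend(w, block_size)  # Ascend (torch_npu) packing
            return cls._pack_xpu(w, block_size)         # original XPU packing

        @classmethod
        def _pack_ascend(cls, w, block_size):
            raise NotImplementedError("placeholder for the Ascend packing path")

        @classmethod
        def _pack_xpu(cls, w, block_size):
            raise NotImplementedError("placeholder for the XPU packing path")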

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 0 additions & 4 deletions
@@ -13,9 +13,6 @@
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32Tensor,
 )
-from .int4.int4_plain_int32_tensor_npu import (
-    Int4PlainInt32TensorNPU,
-)
 from .int4.int4_preshuffled_tensor import (
     Int4PreshuffledTensor,
 )
@@ -39,7 +36,6 @@
     "Int4PreshuffledTensor",
     "Int4MarlinSparseTensor",
     "Int4PlainInt32Tensor",
-    "Int4PlainInt32TensorNPU",
     "Int4TilePackedTo4dTensor",
     "Float8Tensor",
     "QuantizeTensorToFloat8Kwargs",
