Skip to content

Commit ca8f056

Browse files
committed
ruff format cleanup, replace error types, add torch version check
1 parent ea2aa7a commit ca8f056

File tree

2 files changed

+36
-32
lines changed

2 files changed

+36
-32
lines changed

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def test_activation_prescaling(self):
105105
"NPU not available",
106106
)
107107
class Int4PlainInt32TensorNPU(TestCase):
108-
109108
@parametrize("device", ["npu"])
110109
@parametrize(
111110
"sizes",
@@ -153,9 +152,9 @@ def test_activation_prescaling(self, device, dtype):
153152
original = linear(input)
154153
quantize_(linear, get_config(64))
155154
qw = linear.weight
156-
assert isinstance(
157-
qw, SupportsActivationPreScaling
158-
), "Expected int4 tensor supports activation prescaling"
155+
assert isinstance(qw, SupportsActivationPreScaling), (
156+
"Expected int4 tensor supports activation prescaling"
157+
)
159158
assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
160159
_ACT_PRE_SCALE = 2
161160
qw.act_pre_scale = _ACT_PRE_SCALE

torchao/quantization/quantize_/workflows/int4/int4_plain_int32_tensor.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,7 @@
1414
choose_qparams_affine,
1515
quantize_affine,
1616
)
17-
from torchao.utils import (
18-
TorchAOBaseTensor,
19-
)
17+
from torchao.utils import TorchAOBaseTensor, torch_version_at_least
2018

2119
__all__ = [
2220
"Int4PlainInt32Tensor",
@@ -96,7 +94,10 @@ def from_hp(
9694
elif w.device.type == "npu":
9795
return _from_hp_npu(cls, w, block_size)
9896
else:
99-
raise AssertionError(f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet.")
97+
raise NotImplementedError(
98+
f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet."
99+
)
100+
100101

101102
def _from_hp_xpu(
102103
cls,
@@ -156,32 +157,34 @@ def _from_hp_xpu(
156157
act_pre_scale=None,
157158
)
158159

160+
159161
def _from_hp_npu(
160162
cls,
161163
w: torch.Tensor,
162164
block_size: List[int],
163165
):
166+
# Require PyTorch 2.7.1+ for NPU backend ops and backward compatibility.
167+
assert torch_version_at_least("2.7.1"), (
168+
"Need pytorch 2.7.1+ for NPU backend op support."
169+
)
170+
164171
assert w.ndim == 2 and w.device.type == "npu", (
165172
f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
166173
)
167174
assert len(block_size) == w.ndim
168175
assert w.dtype in [torch.float16, torch.bfloat16], (
169176
f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
170177
)
171-
178+
172179
group_size = block_size[1]
173180
k_dim = w.shape[-1]
174-
assert (
175-
group_size >= 32
176-
and group_size % 32 == 0
177-
and group_size < k_dim
178-
), (
181+
assert group_size >= 32 and group_size % 32 == 0 and group_size < k_dim, (
179182
f"Invalid group_size={group_size}: "
180183
f"expected to be a multiple of 32, "
181184
f"in range [32, {k_dim - 1}] for per-group quantization, "
182185
f"but got group_size={group_size} (k_dim={k_dim})."
183186
)
184-
187+
185188
original_shape = w.shape
186189
mapping_type = MappingType.ASYMMETRIC
187190
target_dtype = torch.int32
@@ -190,7 +193,7 @@ def _from_hp_npu(
190193
eps = 1e-6
191194
scale_dtype = w.dtype
192195
zero_point_dtype = w.dtype
193-
196+
194197
scale, zero_point = choose_qparams_affine(
195198
w,
196199
mapping_type,
@@ -202,7 +205,7 @@ def _from_hp_npu(
202205
scale_dtype,
203206
zero_point_dtype,
204207
)
205-
208+
206209
int_data = quantize_affine(
207210
w,
208211
block_size,
@@ -212,31 +215,31 @@ def _from_hp_npu(
212215
quant_min,
213216
quant_max,
214217
)
215-
218+
216219
assert int_data.dtype == torch.int32, (
217220
"torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"
218221
)
219222
assert int_data.shape[-1] % 8 == 0, (
220223
f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got {int_data.shape[-1]}"
221224
)
222-
225+
223226
packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
224227
int_data.contiguous(), 0
225228
)
226-
229+
227230
scale = scale.reshape(int_data.shape[0], -1)
228231
zero_point = zero_point.reshape(int_data.shape[0], -1)
229-
232+
230233
return Int4PlainInt32Tensor(
231-
packed_weight,
234+
packed_weight.contiguous(),
232235
scale.transpose(0, 1).contiguous(),
233236
zero_point.transpose(0, 1).contiguous(),
234237
block_size,
235238
original_shape,
236239
act_pre_scale=None,
237240
)
238-
239-
241+
242+
240243
implements = Int4PlainInt32Tensor.implements
241244
implements_torch_function = Int4PlainInt32Tensor.implements_torch_function
242245

@@ -249,20 +252,22 @@ def _(func, types, args, kwargs):
249252
args[1],
250253
args[2] if len(args) > 2 else None,
251254
)
252-
255+
253256
if input_tensor.device.type == "xpu":
254257
return _linear_xpu(input_tensor, weight_tensor, bias)
255258
elif input_tensor.device.type == "npu":
256259
return _linear_npu(input_tensor, weight_tensor, bias)
257260
else:
258-
raise AssertionError(f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet.")
261+
raise NotImplementedError(
262+
f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet."
263+
)
259264

260265

261266
def _linear_xpu(
262267
input_tensor,
263268
weight_tensor,
264269
bias,
265-
):
270+
):
266271
assert input_tensor.device.type == "xpu", (
267272
f"For XPU device only but got: {input_tensor.device}"
268273
)
@@ -306,11 +311,12 @@ def _linear_xpu(
306311
y += bias
307312
return y.to(orig_dtype)
308313

314+
309315
def _linear_npu(
310316
input_tensor,
311317
weight_tensor,
312318
bias,
313-
):
319+
):
314320
assert input_tensor.device.type == "npu", (
315321
f"For NPU device only but got: {input_tensor.device.type}"
316322
)
@@ -355,24 +361,23 @@ def _linear_npu(
355361

356362
y = torch.ops.npu.npu_weight_quant_batchmatmul(
357363
x=act_mat,
358-
weight=packed_weight.contiguous().transpose(-1, -2),
364+
weight=packed_weight.transpose(-1, -2),
359365
antiquant_scale=scale,
360366
antiquant_offset=zero_point,
361367
antiquant_group_size=groupsize,
362368
bias=bias,
363369
)
364-
370+
365371
# remove out_feature padding
366372
assert weight_tensor.ndim == 2
367373
orig_out_features = weight_tensor.shape[-2]
368374
y = y[:, :orig_out_features]
369375
y = y.reshape(*orig_act_size[:-1], orig_out_features)
370-
376+
371377
return y.to(orig_dtype)
372378

373379

374380
Int4PlainInt32Tensor.__module__ = "torchao.quantization"
375381

376382
# Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
377383
torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
378-

0 commit comments

Comments
 (0)