     choose_qparams_affine,
     quantize_affine,
 )
-from torchao.utils import (
-    TorchAOBaseTensor,
-)
+from torchao.utils import TorchAOBaseTensor, torch_version_at_least

 __all__ = [
     "Int4PlainInt32Tensor",
@@ -96,7 +94,10 @@ def from_hp(
         elif w.device.type == "npu":
             return _from_hp_npu(cls, w, block_size)
         else:
-            raise AssertionError(f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet.")
+            raise NotImplementedError(
+                f"Int4PlainInt32Tensor does not support device '{w.device.type}' yet."
+            )
+

 def _from_hp_xpu(
     cls,
@@ -156,32 +157,34 @@ def _from_hp_xpu(
         act_pre_scale=None,
     )

+
 def _from_hp_npu(
     cls,
     w: torch.Tensor,
     block_size: List[int],
 ):
+    # Require PyTorch 2.7.1+ for NPU backend ops and backward compatibility.
+    assert torch_version_at_least("2.7.1"), (
+        "Need PyTorch 2.7.1+ for NPU backend op support."
+    )
+
     assert w.ndim == 2 and w.device.type == "npu", (
         f"Expecting 2D tensor on NPU, but got: {w.shape} on {w.device.type}"
     )
     assert len(block_size) == w.ndim
     assert w.dtype in [torch.float16, torch.bfloat16], (
         f"Expecting float16 or bfloat16 weight tensor, but got: {w.dtype}"
     )
-
+
     group_size = block_size[1]
     k_dim = w.shape[-1]
-    assert (
-        group_size >= 32
-        and group_size % 32 == 0
-        and group_size < k_dim
-    ), (
+    assert group_size >= 32 and group_size % 32 == 0 and group_size < k_dim, (
         f"Invalid group_size={group_size}: "
         f"expected to be a multiple of 32, "
         f"in range [32, {k_dim - 1}] for per-group quantization, "
         f"but got group_size={group_size} (k_dim={k_dim})."
     )
-
+
     original_shape = w.shape
     mapping_type = MappingType.ASYMMETRIC
     target_dtype = torch.int32
@@ -190,7 +193,7 @@ def _from_hp_npu(
     eps = 1e-6
     scale_dtype = w.dtype
     zero_point_dtype = w.dtype
-
+
     scale, zero_point = choose_qparams_affine(
         w,
         mapping_type,
@@ -202,7 +205,7 @@ def _from_hp_npu(
         scale_dtype,
         zero_point_dtype,
     )
-
+
     int_data = quantize_affine(
         w,
         block_size,
@@ -212,31 +215,31 @@ def _from_hp_npu(
         quant_min,
         quant_max,
     )
-
+
     assert int_data.dtype == torch.int32, (
         "torch.ops.npu.npu_convert_weight_to_int4pack expects `int32` dtype"
     )
     assert int_data.shape[-1] % 8 == 0, (
220223 f"torch.ops.npu.npu_convert_weight_to_int4pack expects last dim must be aligned to 8,but got { int_data .shape [- 1 ]} "
     )
-
+
     packed_weight = torch.ops.npu.npu_convert_weight_to_int4pack(
         int_data.contiguous(), 0
     )
-
+
     scale = scale.reshape(int_data.shape[0], -1)
     zero_point = zero_point.reshape(int_data.shape[0], -1)
-
+
     return Int4PlainInt32Tensor(
-        packed_weight,
+        packed_weight.contiguous(),
         scale.transpose(0, 1).contiguous(),
         zero_point.transpose(0, 1).contiguous(),
         block_size,
         original_shape,
         act_pre_scale=None,
     )
-
-
+
+
 implements = Int4PlainInt32Tensor.implements
 implements_torch_function = Int4PlainInt32Tensor.implements_torch_function

@@ -249,20 +252,22 @@ def _(func, types, args, kwargs):
         args[1],
         args[2] if len(args) > 2 else None,
     )
-
+
     if input_tensor.device.type == "xpu":
         return _linear_xpu(input_tensor, weight_tensor, bias)
     elif input_tensor.device.type == "npu":
         return _linear_npu(input_tensor, weight_tensor, bias)
     else:
-        raise AssertionError(f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet.")
+        raise NotImplementedError(
+            f"Int4PlainInt32Tensor does not support device '{input_tensor.device.type}' yet."
+        )


 def _linear_xpu(
     input_tensor,
     weight_tensor,
     bias,
-):
+):
     assert input_tensor.device.type == "xpu", (
         f"For XPU device only but got: {input_tensor.device}"
     )
@@ -306,11 +311,12 @@ def _linear_xpu(
         y += bias
     return y.to(orig_dtype)

+
 def _linear_npu(
     input_tensor,
     weight_tensor,
     bias,
-):
+):
     assert input_tensor.device.type == "npu", (
         f"For NPU device only but got: {input_tensor.device.type}"
     )
@@ -355,24 +361,23 @@ def _linear_npu(

     y = torch.ops.npu.npu_weight_quant_batchmatmul(
         x=act_mat,
-        weight=packed_weight.contiguous().transpose(-1, -2),
+        weight=packed_weight.transpose(-1, -2),
         antiquant_scale=scale,
         antiquant_offset=zero_point,
         antiquant_group_size=groupsize,
         bias=bias,
     )
-
+
     # remove out_feature padding
     assert weight_tensor.ndim == 2
     orig_out_features = weight_tensor.shape[-2]
     y = y[:, :orig_out_features]
     y = y.reshape(*orig_act_size[:-1], orig_out_features)
-
+
     return y.to(orig_dtype)


 Int4PlainInt32Tensor.__module__ = "torchao.quantization"

 # Allow a model with Int4PlainInt32Tensor weights to be loaded with `weights_only=True`
 torch.serialization.add_safe_globals([Int4PlainInt32Tensor])
-
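For reviewers, a minimal usage sketch of the path this diff touches. Everything below is hypothetical and not part of the PR: it assumes an Ascend NPU build of PyTorch 2.7.1+ with torch_npu loaded, and that Int4PlainInt32Tensor is importable from torchao.quantization as the __module__ assignment above suggests; the shapes and group size are illustrative, chosen only to satisfy the asserts in _from_hp_npu.

# Hypothetical usage sketch (not part of this PR).
import torch
import torch.nn.functional as F

from torchao.quantization import Int4PlainInt32Tensor  # assumed export path

# Illustrative sizes: out_features=256, in_features=128, group_size=64.
# _from_hp_npu requires group_size >= 32, group_size % 32 == 0, and
# group_size < in_features; the weight must be 2D float16/bfloat16 on "npu".
w = torch.randn(256, 128, dtype=torch.bfloat16, device="npu")
qw = Int4PlainInt32Tensor.from_hp(w, block_size=[1, 64])

# F.linear on an NPU input is intercepted via implements_torch_function and
# routed to _linear_npu, which calls torch.ops.npu.npu_weight_quant_batchmatmul.
x = torch.randn(8, 128, dtype=torch.bfloat16, device="npu")
y = F.linear(x, qw)  # bfloat16 output of shape (8, 256)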