 # LICENSE file in the root directory of this source tree.
 
 import tempfile
-import unittest
 
+import pytest
 import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests
 from torch.testing._internal.common_utils import (
     TestCase,
-    instantiate_parametrized_tests,
     parametrize,
     run_tests,
 )
@@ -33,9 +33,19 @@ def get_config(group_size):
     )
 
 
-@unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
-@unittest.skipIf(not torch.xpu.is_available(), "XPU not available")
-class Int4PlainInt32TensorXPU(TestCase):
+class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
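+    # Skip a generated test when the installed torch is older than the
+    # minimum version declared for that backend in _MIN_VER.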
+    def setUp(self):
+        min_req = type(self)._MIN_VER.get(self.device_type)
+        if not torch_version_at_least(min_req):
+            self.skipTest(
+                f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
+            )
+
     @parametrize(
         "sizes",
         [
@@ -46,90 +56,36 @@ class Int4PlainInt32TensorXPU(TestCase):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = "xpu"
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
+        if "npu" in device and group_size == K:
+            pytest.skip(
+                f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
+            )
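+        # compute_error returns SQNR in dB; npu is held to a looser bound (10) than xpu (20).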
+        threshold = thresholds.get(device.split(":")[0])
+
         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
         original = linear(input)
         quantize_(linear, get_config(group_size))
         quantized = linear(input)
-        self.assertTrue(compute_error(original, quantized) > 20)
+        self.assertTrue(compute_error(original, quantized) > threshold)
 
-        compiled_linear = torch.compile(linear)
-        quantized_and_compiled = compiled_linear(input)
-        self.assertTrue(compute_error(original, quantized_and_compiled) > 20)
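+        # The torch.compile path is only exercised on xpu.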
+        if "xpu" in device:
+            compiled_linear = torch.compile(linear)
+            quantized_and_compiled = compiled_linear(input)
+            self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
 
     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device="xpu")
-        quantize_(linear, get_config(group_size=128))
-        self.assertEqual(
-            str(type(linear.weight)),
-            "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
-        )
-
-        with tempfile.NamedTemporaryFile() as f:
-            torch.save(linear.state_dict(), f)
-            f.seek(0)
-            state_dict = torch.load(f)
-            self.assertEqual(
-                str(type(state_dict["weight"])),
-                "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
-            )
-
-    def test_activation_prescaling(self):
-        dtype = torch.bfloat16
-        device = "xpu"
-        input = torch.randn(1, 128, dtype=dtype, device=device)
-        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
-        original = linear(input)
-        quantize_(linear, get_config(128))
-        qw = linear.weight
-        assert isinstance(qw, SupportsActivationPreScaling), (
-            "Expected int4 tensor supports activation prescaling"
-        )
-        assert qw.act_pre_scale is None, "Default `act_pre_scale` is None"
-        _ACT_PRE_SCALE = 2
-        qw.act_pre_scale = _ACT_PRE_SCALE
-        quantized = linear(input)
-
-        # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 20)
-
+    def test_module_path(self, device, dtype):
+        device = self.device_type
+        K, N, group_size = 128, 256, 128
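+        # npu does not support group_size == K (see the skip in test_linear), so use 64.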
+        if "npu" in device:
+            group_size = 64
 
-@unittest.skipIf(not torch_version_at_least("2.7.1"), "Need pytorch 2.7.1+")
-@unittest.skipIf(
-    torch.accelerator.current_accelerator().type != "npu"
-    or not torch.accelerator.is_available(),
-    "NPU not available",
-)
-class Int4PlainInt32TensorNPU(TestCase):
-    @parametrize("device", ["npu"])
-    @parametrize(
-        "sizes",
-        [
-            ((128,), 256, 128),
-            ((32, 128), 512, 128),
-            ((2, 32, 128), 256, 128),
-        ],
-    )
-    @parametrize("dtype", [torch.float16, torch.bfloat16])
-    @parametrize("group_size", [32, 64])
-    def test_linear(self, device, sizes, dtype, group_size):
-        M, N, K = sizes
-        input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
-        orig_output = linear(input)
         quantize_(linear, get_config(group_size))
-        quantized_output = linear(input)
-        self.assertTrue(compute_error(orig_output, quantized_output) > 10)
-
-    @parametrize("device", ["npu"])
-    @parametrize("dtype", [torch.float16, torch.bfloat16])
-    def test_module_path(self, device, dtype):
-        linear = torch.nn.Linear(128, 256, dtype=dtype, device=device)
-        quantize_(linear, get_config(group_size=64))
         self.assertEqual(
             str(type(linear.weight)),
             "<class 'torchao.quantization.Int4PlainInt32Tensor'>",
@@ -144,13 +100,22 @@ def test_module_path(self, device, dtype):
144100 "<class 'torchao.quantization.Int4PlainInt32Tensor'>" ,
145101 )
146102
147- @parametrize ("device" , ["npu" ])
148103 @parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
149- def test_activation_prescaling (self , device , dtype ):
150- input = torch .randn (1 , 128 , dtype = dtype , device = device )
151- linear = torch .nn .Linear (128 , 256 , bias = False , dtype = dtype , device = device )
104+ @parametrize ("thresholds" , [{"xpu" : 20 , "npu" : 10 }])
105+ def test_activation_prescaling (self , device , dtype , thresholds ):
106+ device = self .device_type
107+ if "xpu" in device and dtype == torch .float16 :
+            pytest.skip(f"{device} test_activation_prescaling does not cover {dtype}")
+
+        threshold = thresholds.get(device.split(":")[0])
+        K, N, group_size = 128, 256, 128
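+        # Same constraint as test_linear: npu cannot use group_size == K, so drop to 64.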
+        if "npu" in device:
+            group_size = 64
+
+        input = torch.randn(1, K, dtype=dtype, device=device)
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
         original = linear(input)
-        quantize_(linear, get_config(64))
+        quantize_(linear, get_config(group_size))
         qw = linear.weight
         assert isinstance(qw, SupportsActivationPreScaling), (
             "Expected int4 tensor supports activation prescaling"
@@ -161,11 +126,13 @@ def test_activation_prescaling(self, device, dtype):
         quantized = linear(input)
 
         # making sure activation pre scaling is successfully applied to the activation
-        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > 10)
+        self.assertTrue(compute_error(original * _ACT_PRE_SCALE, quantized) > threshold)
+
 
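+# Generates one concrete test class per backend (e.g. Int4PlainInt32TensorXPU)
+# and supplies the `device` argument to each test; allow_xpu=True opts xpu into
+# the device-type test framework, which excludes it by default.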
+instantiate_device_type_tests(
+    Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
+)
 
-instantiate_parametrized_tests(Int4PlainInt32TensorXPU)
-instantiate_parametrized_tests(Int4PlainInt32TensorNPU)
 
 if __name__ == "__main__":
     run_tests()