44# This source code is licensed under the BSD 3-Clause license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ import pytest
78import tempfile
89import unittest
910
@@ -33,103 +34,88 @@ def get_config(group_size):
3334 )
3435
3536
36- @unittest .skipIf (not torch_version_at_least ("2.8.0" ), "Need pytorch 2.8+" )
37- @unittest .skipIf (not torch .xpu .is_available (), "XPU not available" )
38- class Int4PlainInt32TensorXPU (TestCase ):
39- @parametrize (
40- "sizes" ,
41- [
42- ((128 ,), 256 , 128 ),
43- ((32 , 128 ), 512 , 128 ),
44- ((2 , 32 , 128 ), 256 , 12 ),
45- ],
46- )
47- @parametrize ("dtype" , [torch .bfloat16 , torch .half ])
48- @parametrize ("group_size" , [32 , 64 , 128 ])
49- def test_linear (self , sizes , dtype , group_size ):
50- device = "xpu"
51- M , N , K = sizes
52- input = torch .randn (* M , K , dtype = dtype , device = device )
53- linear = torch .nn .Linear (K , N , dtype = dtype , device = device )
54- original = linear (input )
55- quantize_ (linear , get_config (group_size ))
56- quantized = linear (input )
57- self .assertTrue (compute_error (original , quantized ) > 20 )
# Minimum torch version required before a backend's tests are run, keyed by
# the backend's attribute name under ``torch`` (e.g. ``torch.xpu``).
_MIN_VER: dict[str, str] = {
    "xpu": "2.8.0",
    "npu": "2.7.1",
}

# Per-backend lower bound that ``compute_error(...)`` must exceed for a
# quantized result to be accepted by the tests.
THRESHOLD: dict[str, int] = {"xpu": 20, "npu": 10}

# Candidate backends, in probing order. Derived from ``_MIN_VER`` so that every
# candidate is guaranteed to have a minimum-version entry.
ALL_DEVICES: tuple[str, ...] = tuple(_MIN_VER)
6244
63- @parametrize ("dtype" , [torch .bfloat16 , torch .half ])
64- def test_module_path (self , dtype ):
65- linear = torch .nn .Linear (128 , 256 , dtype = dtype , device = "xpu" )
66- quantize_ (linear , get_config (group_size = 128 ))
67- self .assertEqual (
68- str (type (linear .weight )),
69- "<class 'torchao.quantization.Int4PlainInt32Tensor'>" ,
70- )
7145
72- with tempfile .NamedTemporaryFile () as f :
73- torch .save (linear .state_dict (), f )
74- f .seek (0 )
75- state_dict = torch .load (f )
76- self .assertEqual (
77- str (type (state_dict ["weight" ])),
78- "<class 'torchao.quantization.Int4PlainInt32Tensor'>" ,
79- )
def _get_available_devices() -> tuple[list[str], list[str]]:
    """Probe every candidate backend and report which ones the tests can use.

    A backend from ``ALL_DEVICES`` is usable when ``torch.<name>`` exists, its
    ``is_available()`` returns a truthy value, and the installed torch meets
    the backend's minimum version from ``_MIN_VER``.

    Returns:
        A pair ``(available_devices, messages)``: the usable backend names in
        ``ALL_DEVICES`` order, and one human-readable status line for each
        candidate backend (usable or not).
    """
    available_devices: list[str] = []
    messages: list[str] = []
    for name in ALL_DEVICES:
        mod = getattr(torch, name, None)
        if mod is None:
            # The backend module is not attached to ``torch`` at all.
            messages.append(f"{name}: not found in torch")
            continue

        min_version = _MIN_VER[name]
        avail = mod.is_available()
        # The version check is only evaluated when the device is present.
        usable = bool(avail and torch_version_at_least(min_version))
        if usable:
            available_devices.append(name)

        status = [
            f"available={avail}",
            f"min_version_req={min_version}",
            f"torch_version={torch.__version__}",
            "OK" if usable else "FAIL",
        ]
        messages.append(f"{name}: " + ", ".join(status))

    return available_devices, messages
9667
97- # making sure activation pre scaling is successfully applied to the activation
98- self .assertTrue (compute_error (original * _ACT_PRE_SCALE , quantized ) > 20 )
68+
# Probe once at import time: the ``parametrize`` decorators on the test class
# read ``AVAILABLE_DEVICES`` while the class body is being evaluated.
AVAILABLE_DEVICES, MESSAGES = _get_available_devices()

# Print the probe result so a run's log shows why a backend was or was not
# included in the parametrization.
print("\nDevice Status:")
for msg in MESSAGES:
    print(" ", msg)
9973
10074
101- @unittest .skipIf (not torch_version_at_least ("2.7.1" ), "Need pytorch 2.7.1+" )
10275@unittest .skipIf (
103- torch .accelerator .current_accelerator ().type != "npu"
104- or not torch .accelerator .is_available (),
105- "NPU not available" ,
76+ not AVAILABLE_DEVICES , f"No available devices: { ', ' .join (ALL_DEVICES )} "
10677)
107- class Int4PlainInt32TensorNPU (TestCase ):
108- @parametrize ("device" , [ "npu" ] )
78+ class Int4PlainInt32Tensor (TestCase ):
@parametrize("device", AVAILABLE_DEVICES)
@parametrize(
    "sizes",
    [
        ((128,), 256, 128),
        ((32, 128), 512, 128),
        # NOTE(review): K=12 is smaller than every parametrized group_size;
        # confirm this value is intended rather than 128.
        ((2, 32, 128), 256, 12),
    ],
)
@parametrize("dtype", [torch.bfloat16, torch.half])
@parametrize("group_size", [32, 64, 128])
def test_linear(self, device, sizes, dtype, group_size):
    """Compare a quantized ``nn.Linear`` against its unquantized output.

    Args:
        device: Backend name taken from ``AVAILABLE_DEVICES``.
        sizes: ``(M, N, K)`` where ``M`` is a tuple of leading input
            dimensions, ``N`` the output features and ``K`` the input
            features of the linear layer.
        dtype: Floating-point dtype used for the input and the layer.
        group_size: Value passed to ``get_config`` to build the quantization
            config.

    The test passes when ``compute_error(original, quantized)`` exceeds the
    per-device entry in ``THRESHOLD``; on ``xpu`` the same check is repeated
    for the ``torch.compile``-d module.
    """
    M, N, K = sizes
    if device == "npu" and group_size == K:
        # ``self.skipTest`` raises ``unittest.SkipTest``, which is reported as
        # a skip by both the unittest runner (``run_tests()``) and pytest.
        self.skipTest(
            f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
        )
    # Index directly so an unknown device fails here with a clear KeyError
    # instead of a ``float > None`` TypeError further down.
    threshold = THRESHOLD[device]

    input = torch.randn(*M, K, dtype=dtype, device=device)
    linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
    original = linear(input)
    quantize_(linear, get_config(group_size))
    quantized = linear(input)
    self.assertTrue(compute_error(original, quantized) > threshold)

    # The compiled path is only exercised on xpu.
    if device == "xpu":
        compiled_linear = torch.compile(linear)
        quantized_and_compiled = compiled_linear(input)
        self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)
109+
110+ @parametrize ("device" , AVAILABLE_DEVICES )
111+ @parametrize ("dtype" , [torch .bfloat16 , torch .half ])
130112 def test_module_path (self , device , dtype ):
131- linear = torch .nn .Linear (128 , 256 , dtype = dtype , device = device )
132- quantize_ (linear , get_config (group_size = 64 ))
113+ K , N , group_size = 128 , 256 , 128
114+ if device == "npu" :
115+ group_size = 64
116+
117+ linear = torch .nn .Linear (K , N , dtype = dtype , device = device )
118+ quantize_ (linear , get_config (group_size ))
133119 self .assertEqual (
134120 str (type (linear .weight )),
135121 "<class 'torchao.quantization.Int4PlainInt32Tensor'>" ,
@@ -144,13 +130,21 @@ def test_module_path(self, device, dtype):
144130 "<class 'torchao.quantization.Int4PlainInt32Tensor'>" ,
145131 )
146132
147- @parametrize ("device" , [ "npu" ] )
133+ @parametrize ("device" , AVAILABLE_DEVICES )
148134 @parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
149135 def test_activation_prescaling (self , device , dtype ):
136+ if device == "xpu" and dtype == torch .float16 :
137+ pytest .skip (f"{ device } test_activation_prescaling don't test { dtype } " )
138+
139+ threshold = THRESHOLD .get (device )
140+ K , N , group_size = 128 , 256 , 128
141+ if device == "npu" :
142+ group_size = 64
143+
150144 input = torch .randn (1 , 128 , dtype = dtype , device = device )
151145 linear = torch .nn .Linear (128 , 256 , bias = False , dtype = dtype , device = device )
152146 original = linear (input )
153- quantize_ (linear , get_config (64 ))
147+ quantize_ (linear , get_config (group_size ))
154148 qw = linear .weight
155149 assert isinstance (qw , SupportsActivationPreScaling ), (
156150 "Expected int4 tensor supports activation prescaling"
@@ -161,11 +155,11 @@ def test_activation_prescaling(self, device, dtype):
161155 quantized = linear (input )
162156
163157 # making sure activation pre scaling is successfully applied to the activation
164- self .assertTrue (compute_error (original * _ACT_PRE_SCALE , quantized ) > 10 )
158+ self .assertTrue (compute_error (original * _ACT_PRE_SCALE , quantized ) > threshold )
159+
165160
# Expand the ``@parametrize``-decorated methods of the merged test class into
# concrete test cases.
instantiate_parametrized_tests(Int4PlainInt32Tensor)


if __name__ == "__main__":
    run_tests()
0 commit comments