44# This source code is licensed under the BSD 3-Clause license found in the
55# LICENSE file in the root directory of this source tree.
66
7- import pytest
87import tempfile
9- import unittest
108
9+ import pytest
1110import torch
11+ from torch .testing ._internal .common_device_type import instantiate_device_type_tests
1212from torch .testing ._internal .common_utils import (
1313 TestCase ,
14- instantiate_parametrized_tests ,
1514 parametrize ,
1615 run_tests ,
1716)
@@ -34,49 +33,22 @@ def get_config(group_size):
3433 )
3534
3635
36+ _ALL_DEVICES = ("xpu" , "npu" )
3737_MIN_VER = {
3838 "xpu" : "2.8.0" ,
3939 "npu" : "2.7.1" ,
4040}
41- THRESHOLD = {"xpu" : 20 , "npu" : 10 }
42-
43- ALL_DEVICES = ("xpu" , "npu" )
44-
45-
46- def _get_available_devices () -> tuple [list [str ], list [str ]]:
47- available_devices = []
48- messages = []
49- for name in ALL_DEVICES :
50- mod = getattr (torch , name , None )
51- if mod is None :
52- messages .append (f"{ name } : not found in torch" )
53- continue
54- avail = mod .is_available ()
55- status = []
56- status .append (f"available={ avail } " )
57- status .append (f"min_version_req={ _MIN_VER [name ]} " )
58- status .append (f"torch_version={ torch .__version__ } " )
59- if avail and torch_version_at_least (_MIN_VER [name ]):
60- available_devices .append (name )
61- status .append ("OK" )
62- else :
63- status .append ("FAIL" )
64- messages .append (f"{ name } : " + ", " .join (status ))
65-
66- return available_devices , messages
67-
68-
69- AVAILABLE_DEVICES , MESSAGES = _get_available_devices ()
70- print ("\n Device Status:" )
71- for msg in MESSAGES :
72- print (" " , msg )
73-
74-
75- @unittest .skipIf (
76- not AVAILABLE_DEVICES , f"No available devices: { ', ' .join (ALL_DEVICES )} "
77- )
41+ _THRESHOLD = {"xpu" : 20 , "npu" : 10 }
42+
43+
7844class Int4PlainInt32Tensor (TestCase ):
79- @parametrize ("device" , AVAILABLE_DEVICES )
45+ def setUp (self ):
46+ min_req = _MIN_VER .get (self .device_type )
47+ if not torch_version_at_least (min_req ):
48+ self .skipTest (
49+ f"{ self .device_type } requires torch >= { min_req } , current { torch .__version__ } "
50+ )
51+
8052 @parametrize (
8153 "sizes" ,
8254 [
@@ -87,13 +59,14 @@ class Int4PlainInt32Tensor(TestCase):
8759 )
8860 @parametrize ("dtype" , [torch .bfloat16 , torch .half ])
8961 @parametrize ("group_size" , [32 , 64 , 128 ])
90- def test_linear (self , device , sizes , dtype , group_size ):
62+ def test_linear (self , sizes , dtype , group_size ):
63+ device = self .device_type
9164 M , N , K = sizes
9265 if device == "npu" and group_size == K :
9366 pytest .skip (
9467 f"{ device } does not support group_size equal to K dimension ({ group_size } == { K } )"
9568 )
96- threshold = THRESHOLD .get (device )
69+ threshold = _THRESHOLD .get (device )
9770
9871 input = torch .randn (* M , K , dtype = dtype , device = device )
9972 linear = torch .nn .Linear (K , N , dtype = dtype , device = device )
@@ -107,9 +80,9 @@ def test_linear(self, device, sizes, dtype, group_size):
10780 quantized_and_compiled = compiled_linear (input )
10881 self .assertTrue (compute_error (original , quantized_and_compiled ) > threshold )
10982
110- @parametrize ("device" , AVAILABLE_DEVICES )
11183 @parametrize ("dtype" , [torch .bfloat16 , torch .half ])
112- def test_module_path (self , device , dtype ):
84+ def test_module_path (self , dtype ):
85+ device = self .device_type
11386 K , N , group_size = 128 , 256 , 128
11487 if device == "npu" :
11588 group_size = 64
@@ -130,19 +103,19 @@ def test_module_path(self, device, dtype):
130103 "<class 'torchao.quantization.Int4PlainInt32Tensor'>" ,
131104 )
132105
133- @parametrize ("device" , AVAILABLE_DEVICES )
134106 @parametrize ("dtype" , [torch .float16 , torch .bfloat16 ])
135- def test_activation_prescaling (self , device , dtype ):
107+ def test_activation_prescaling (self , dtype ):
108+ device = self .device_type
136109 if device == "xpu" and dtype == torch .float16 :
137110 pytest .skip (f"{ device } test_activation_prescaling don't test { dtype } " )
138111
139- threshold = THRESHOLD .get (device )
112+ threshold = _THRESHOLD .get (device )
140113 K , N , group_size = 128 , 256 , 128
141114 if device == "npu" :
142115 group_size = 64
143116
144- input = torch .randn (1 , 128 , dtype = dtype , device = device )
145- linear = torch .nn .Linear (128 , 256 , bias = False , dtype = dtype , device = device )
117+ input = torch .randn (1 , K , dtype = dtype , device = device )
118+ linear = torch .nn .Linear (K , N , bias = False , dtype = dtype , device = device )
146119 original = linear (input )
147120 quantize_ (linear , get_config (group_size ))
148121 qw = linear .weight
@@ -158,7 +131,9 @@ def test_activation_prescaling(self, device, dtype):
158131 self .assertTrue (compute_error (original * _ACT_PRE_SCALE , quantized ) > threshold )
159132
160133
161- instantiate_parametrized_tests (Int4PlainInt32Tensor )
134+ instantiate_device_type_tests (
135+ Int4PlainInt32Tensor , globals (), only_for = _ALL_DEVICES , allow_xpu = True
136+ )
162137
163138
164139if __name__ == "__main__" :
0 commit comments