@@ -33,17 +33,14 @@ def get_config(group_size):
     )


-_ALL_DEVICES = ("xpu", "npu")
-_MIN_VER = {
-    "xpu": "2.8.0",
-    "npu": "2.7.1",
-}
-_THRESHOLD = {"xpu": 20, "npu": 10}
-
-
 class Int4PlainInt32Tensor(TestCase):
+    _MIN_VER = {
+        "xpu": "2.8.0",
+        "npu": "2.7.1",
+    }
+
     def setUp(self):
-        min_req = _MIN_VER.get(self.device_type)
+        min_req = type(self)._MIN_VER.get(self.device_type)
         if not torch_version_at_least(min_req):
             self.skipTest(
                 f"{self.device_type} requires torch >= {min_req}, current {torch.__version__}"
@@ -59,14 +56,14 @@ def setUp(self):
     )
     @parametrize("dtype", [torch.bfloat16, torch.half])
     @parametrize("group_size", [32, 64, 128])
-    def test_linear(self, sizes, dtype, group_size):
-        device = self.device_type
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_linear(self, device, sizes, dtype, group_size, thresholds):
         M, N, K = sizes
-        if device == "npu" and group_size == K:
+        if "npu" in device and group_size == K:
             pytest.skip(
                 f"{device} does not support group_size equal to K dimension ({group_size} == {K})"
             )
-        threshold = _THRESHOLD.get(device)
+        threshold = thresholds.get(device.split(":")[0])

         input = torch.randn(*M, K, dtype=dtype, device=device)
         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
@@ -75,16 +72,16 @@ def test_linear(self, sizes, dtype, group_size):
         quantized = linear(input)
         self.assertTrue(compute_error(original, quantized) > threshold)

-        if device == "xpu":
+        if "xpu" in device:
             compiled_linear = torch.compile(linear)
             quantized_and_compiled = compiled_linear(input)
             self.assertTrue(compute_error(original, quantized_and_compiled) > threshold)

     @parametrize("dtype", [torch.bfloat16, torch.half])
-    def test_module_path(self, dtype):
+    def test_module_path(self, device, dtype):
         device = self.device_type
         K, N, group_size = 128, 256, 128
-        if device == "npu":
+        if "npu" in device:
             group_size = 64

         linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
@@ -104,14 +101,15 @@ def test_module_path(self, dtype):
         )

     @parametrize("dtype", [torch.float16, torch.bfloat16])
-    def test_activation_prescaling(self, dtype):
+    @parametrize("thresholds", [{"xpu": 20, "npu": 10}])
+    def test_activation_prescaling(self, device, dtype, thresholds):
         device = self.device_type
-        if device == "xpu" and dtype == torch.float16:
+        if "xpu" in device and dtype == torch.float16:
             pytest.skip(f"{device} test_activation_prescaling don't test {dtype}")

-        threshold = _THRESHOLD.get(device)
+        threshold = thresholds.get(device.split(":")[0])
         K, N, group_size = 128, 256, 128
-        if device == "npu":
+        if "npu" in device:
             group_size = 64

         input = torch.randn(1, K, dtype=dtype, device=device)
@@ -132,7 +130,7 @@ def test_activation_prescaling(self, dtype):


 instantiate_device_type_tests(
-    Int4PlainInt32Tensor, globals(), only_for=_ALL_DEVICES, allow_xpu=True
+    Int4PlainInt32Tensor, globals(), only_for=("xpu", "npu"), allow_xpu=True
 )

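A minimal standalone sketch (not part of the diff) of the device-key normalization the updated tests rely on: the `device` argument supplied by `instantiate_device_type_tests` may carry an index suffix (e.g. "xpu:0"), so the per-device threshold dict is looked up after stripping the suffix, as in `thresholds.get(device.split(":")[0])`. The helper name `threshold_for` below is illustrative, not from the repo.

```python
# Hypothetical helper mirroring the lookup pattern used in the tests above.
thresholds = {"xpu": 20, "npu": 10}

def threshold_for(device: str):
    # "xpu:0" -> "xpu"; "npu" -> "npu"
    return thresholds.get(device.split(":")[0])

assert threshold_for("xpu:0") == 20
assert threshold_for("npu") == 10
```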