@@ -51,6 +51,10 @@ def forward(self, x):
5151 devices .append ("cuda" )
5252
5353
54+ if torch .xpu .is_available ():
55+ devices .append ("xpu" )
56+
57+
5458class TestAWQ (TestCase ):
5559 def test_awq_config (self ):
5660 base_config = Int4WeightOnlyConfig ()
@@ -79,6 +83,10 @@ def test_awq_functionality(self, device):
7983 # baseline quantization
8084 if device == "cuda" :
8185 base_config = Int4WeightOnlyConfig (group_size = group_size )
86+ elif device == "xpu" :
87+ base_config = Int4WeightOnlyConfig (
88+ group_size = group_size , int4_packing_format = "plain_int32"
89+ )
8290 elif device == "cpu" :
8391 base_config = Int4WeightOnlyConfig (
8492 group_size = group_size , int4_packing_format = "opaque"
@@ -137,6 +145,10 @@ def test_awq_loading(self, device):
137145 # calibrate
138146 if device == "cuda" :
139147 base_config = Int4WeightOnlyConfig (group_size = group_size )
148+ elif device == "xpu" :
149+ base_config = Int4WeightOnlyConfig (
150+ group_size = group_size , int4_packing_format = "plain_int32"
151+ )
140152 elif device == "cpu" :
141153 base_config = Int4WeightOnlyConfig (
142154 group_size = group_size , int4_packing_format = "opaque"
@@ -198,6 +210,10 @@ def test_awq_loading_vllm(self, device):
198210 # calibrate
199211 if device == "cuda" :
200212 base_config = Int4WeightOnlyConfig (group_size = group_size )
213+ elif device == "xpu" :
214+ base_config = Int4WeightOnlyConfig (
215+ group_size = group_size , int4_packing_format = "plain_int32"
216+ )
201217 elif device == "cpu" :
202218 base_config = Int4WeightOnlyConfig (
203219 group_size = group_size , int4_packing_format = "opaque"
0 commit comments