@@ -653,13 +653,13 @@ def test_qat_4w_primitives(self):

         self._assert_close_4w(qat_out, ptq_out)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear

         group_size = 128
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
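The _DEVICE constant referenced throughout these hunks is defined elsewhere in the test module and is not part of this diff. A minimal sketch of the device-selection helper the refactor presumes, assuming CUDA is preferred with XPU as a fallback (the actual logic in torchao may differ), is:

# Hypothetical helper assumed by these hunks; not part of the diff itself.
import torch

if torch.cuda.is_available():
    _DEVICE = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    _DEVICE = torch.device("xpu")
else:
    _DEVICE = None  # tests guarded by skipIf(_DEVICE is None, ...) get skipped

With a single _DEVICE value, one skipIf guard covers both CUDA and XPU runners instead of the CUDA-only _CUDA_IS_AVAILABLE flag.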
@@ -694,7 +694,11 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        _DEVICE is torch.device("xpu"),
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -711,8 +715,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1893,12 +1896,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)

         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1973,7 +1976,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
    def test_quantize_api_int8_int4(self):
        """
        Test the following:
@@ -1986,7 +1989,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2011,7 +2014,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )

-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [