@@ -653,13 +653,13 @@ def test_qat_4w_primitives(self):
 
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = _DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -694,7 +694,11 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        _DEVICE is torch.device("xpu"),
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -703,6 +707,7 @@ def test_qat_4w_quantizer(self):
         inner_k_tiles = 8
         device = torch.device(_DEVICE)
         dtype = torch.bfloat16
+        device = _DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -711,8 +716,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1893,12 +1897,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)
 
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1973,7 +1977,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
    def test_quantize_api_int8_int4(self):
        """
        Test the following:
@@ -1986,7 +1990,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2011,7 +2015,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
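
The skipIf rewrites above all rely on a module-level _DEVICE constant that replaces the old _CUDA_IS_AVAILABLE flag. The diff does not show its definition; below is a minimal sketch of how such a helper could look, assuming it resolves CUDA first, then XPU, and falls back to None (the actual definition in the test file may differ):

import torch

# Resolve the accelerator once at import time. None means "no supported GPU",
# which is what the `_DEVICE is None` skipIf conditions check for.
if torch.cuda.is_available():
    _DEVICE = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    _DEVICE = torch.device("xpu")
else:
    _DEVICE = None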