From 5d86dd0b579d5b70e5069ed8dc202328110ab40a Mon Sep 17 00:00:00 2001
From: Adam Grabowski
Date: Thu, 30 Oct 2025 13:54:18 +0000
Subject: [PATCH] [Intel GPU] Extend TestQAT module with xpu testcases

Add xpu mode to the tests in the test_qat.py TestQAT module
---
 test/quantization/test_qat.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index db33561fa9..dc0e2c848a 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -653,13 +653,13 @@ def test_qat_4w_primitives(self):
 
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = torch.device(_DEVICE)
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -694,7 +694,11 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        _DEVICE == torch.device("xpu"),
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -711,8 +715,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1893,12 +1896,12 @@ def _test_quantize_api_against_ptq(
 
         torch.manual_seed(self.SEED)
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1973,7 +1976,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1986,7 +1989,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2011,7 +2014,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
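
The hunks above gate the GPU tests on a module-level `_DEVICE` value instead of the
old `_CUDA_IS_AVAILABLE` flag, but its definition sits outside the changed lines.
Below is a minimal, hypothetical sketch of how such a helper could be resolved,
assuming the `torch.xpu` backend is available on XPU builds; the actual definition
in test_qat.py may differ.

    # Hypothetical sketch, not part of this patch: pick one GPU device for the
    # whole test module, or None so the skipIf decorators disable the tests.
    import torch

    if torch.cuda.is_available():
        _DEVICE = torch.device("cuda")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        _DEVICE = torch.device("xpu")
    else:
        _DEVICE = None

With a torch.device value like this, the xpu-specific skipIf compares devices by
equality (`_DEVICE == torch.device("xpu")`) rather than identity, so the skip for
the linked torch-xpu-ops issue actually triggers on Intel GPUs.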