Skip to content

Commit 5d86dd0

Browse files
committed
[Intel GPU] Extend TestQAT module with xpu testcases
Add xpu mode to tests from test_qat.py TestQAT module
1 parent 4f5bc7a commit 5d86dd0

File tree

1 file changed

+15
-12
lines changed

1 file changed

+15
-12
lines changed

test/quantization/test_qat.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -653,13 +653,13 @@ def test_qat_4w_primitives(self):
653653

654654
self._assert_close_4w(qat_out, ptq_out)
655655

656-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
656+
@unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
657657
def test_qat_4w_linear(self):
658658
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
659659
from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
660660

661661
group_size = 128
662-
device = torch.device("cuda")
662+
device = torch.device(_DEVICE)
663663
dtype = torch.bfloat16
664664
torch.manual_seed(self.SEED)
665665
qat_linear = Int4WeightOnlyQATLinear(
@@ -694,7 +694,11 @@ def test_qat_4w_quantizer_gradients(self):
694694
quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
695695
self._test_qat_quantized_gradients(quantizer)
696696

697-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
697+
@unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
698+
@unittest.skipIf(
699+
_DEVICE == torch.device("xpu"),
700+
"skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
701+
)
698702
def test_qat_4w_quantizer(self):
699703
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
700704
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -711,8 +715,7 @@ def test_qat_4w_quantizer(self):
711715
inner_k_tiles=inner_k_tiles,
712716
)
713717
ptq_quantizer = Int4WeightOnlyQuantizer(
714-
groupsize=group_size,
715-
inner_k_tiles=inner_k_tiles,
718+
groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
716719
)
717720
qat_model = qat_quantizer.prepare(m)
718721
ptq_model = ptq_quantizer.quantize(m2)
@@ -1893,12 +1896,12 @@ def _test_quantize_api_against_ptq(
18931896
torch.manual_seed(self.SEED)
18941897

18951898
if module_type == "linear":
1896-
m = M().to(dtype).cuda()
1897-
example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
1899+
m = M().to(dtype).to(_DEVICE)
1900+
example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),)
18981901
filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
18991902
elif module_type == "embedding":
1900-
m = M3().to(dtype).cuda()
1901-
example_inputs = (m.example_inputs()[0].cuda(),)
1903+
m = M3().to(dtype).to(_DEVICE)
1904+
example_inputs = (m.example_inputs()[0].to(_DEVICE),)
19021905
filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
19031906
else:
19041907
raise ValueError(f"Unknown module type {module_type}")
@@ -1973,7 +1976,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
19731976
target_convert_sqnr=float("inf"),
19741977
)
19751978

1976-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1979+
@unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
19771980
def test_quantize_api_int8_int4(self):
19781981
"""
19791982
Test the following:
@@ -1986,7 +1989,7 @@ def test_quantize_api_int8_int4(self):
19861989
target_convert_sqnr=float("inf"),
19871990
)
19881991

1989-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
1992+
@unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
19901993
@parametrize(
19911994
"weight_dtype, weight_granularity, dtype",
19921995
[
@@ -2011,7 +2014,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
20112014
dtype=dtype,
20122015
)
20132016

2014-
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
2017+
@unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
20152018
@parametrize(
20162019
"weight_dtype, granularity, dtype, module_type",
20172020
[

0 commit comments

Comments
 (0)