Commit d413980

[Intel GPU] Extend TestQAT module with xpu testcases
Add xpu mode to tests from test_qat.py TestQAT module
1 parent 4f5bc7a commit d413980
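
Note: the tests below gate on a module-level _DEVICE instead of the old _CUDA_IS_AVAILABLE flag. Its definition is outside this diff; a minimal sketch of what it presumably looks like, preferring CUDA and falling back to Intel XPU, is:

    import torch

    # Hypothetical reconstruction, not part of this commit: resolve the test
    # device once at import time; None means no supported GPU, so GPU tests skip.
    if torch.cuda.is_available():
        _DEVICE = torch.device("cuda")
    elif hasattr(torch, "xpu") and torch.xpu.is_available():
        _DEVICE = torch.device("xpu")
    else:
        _DEVICE = None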

File tree: 1 file changed (+16, −12 lines)

test/quantization/test_qat.py

Lines changed: 16 additions & 12 deletions
@@ -653,13 +653,13 @@ def test_qat_4w_primitives(self):
 
         self._assert_close_4w(qat_out, ptq_out)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_qat_4w_linear(self):
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
         from torchao.quantization.qat.linear import Int4WeightOnlyQATLinear
 
         group_size = 128
-        device = torch.device("cuda")
+        device = _DEVICE
         dtype = torch.bfloat16
         torch.manual_seed(self.SEED)
         qat_linear = Int4WeightOnlyQATLinear(
@@ -694,7 +694,11 @@ def test_qat_4w_quantizer_gradients(self):
         quantizer = Int4WeightOnlyQATQuantizer(groupsize=32, inner_k_tiles=8)
         self._test_qat_quantized_gradients(quantizer)
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
+    @unittest.skipIf(
+        _DEVICE is torch.device("xpu"),
+        "skipped due to https://github.com/intel/torch-xpu-ops/issues/1770",
+    )
     def test_qat_4w_quantizer(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
         from torchao.quantization.qat import Int4WeightOnlyQATQuantizer
@@ -703,6 +707,7 @@ def test_qat_4w_quantizer(self):
         inner_k_tiles = 8
         device = torch.device(_DEVICE)
         dtype = torch.bfloat16
+        device = _DEVICE
         torch.manual_seed(self.SEED)
         m = M().to(device).to(dtype)
         m2 = copy.deepcopy(m)
@@ -711,8 +716,7 @@ def test_qat_4w_quantizer(self):
             inner_k_tiles=inner_k_tiles,
         )
         ptq_quantizer = Int4WeightOnlyQuantizer(
-            groupsize=group_size,
-            inner_k_tiles=inner_k_tiles,
+            groupsize=group_size, inner_k_tiles=inner_k_tiles, device=device
         )
         qat_model = qat_quantizer.prepare(m)
         ptq_model = ptq_quantizer.quantize(m2)
@@ -1893,12 +1897,12 @@ def _test_quantize_api_against_ptq(
         torch.manual_seed(self.SEED)
 
         if module_type == "linear":
-            m = M().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].to(dtype).cuda(),)
+            m = M().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(dtype).to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Linear)
         elif module_type == "embedding":
-            m = M3().to(dtype).cuda()
-            example_inputs = (m.example_inputs()[0].cuda(),)
+            m = M3().to(dtype).to(_DEVICE)
+            example_inputs = (m.example_inputs()[0].to(_DEVICE),)
             filter_fn = lambda m, fqn: isinstance(m, torch.nn.Embedding)
         else:
             raise ValueError(f"Unknown module type {module_type}")
@@ -1973,7 +1977,7 @@ def test_quantize_api_int4(self, version: int, packing_format: Int4PackingFormat
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     def test_quantize_api_int8_int4(self):
         """
         Test the following:
@@ -1986,7 +1990,7 @@ def test_quantize_api_int8_int4(self):
             target_convert_sqnr=float("inf"),
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, weight_granularity, dtype",
         [
@@ -2011,7 +2015,7 @@ def test_quantize_api_int8_intx(self, weight_dtype, weight_granularity, dtype):
             dtype=dtype,
         )
 
-    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    @unittest.skipIf(_DEVICE is None, "skipping when GPU is not available")
     @parametrize(
         "weight_dtype, granularity, dtype, module_type",
         [
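
With these changes the same test bodies run on whichever accelerator _DEVICE resolves to; on an Intel GPU machine they could presumably be exercised with an invocation along the lines of:

    # assumed command, not part of this commit
    python -m pytest test/quantization/test_qat.py -k "test_qat_4w or test_quantize_api"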
