Skip to content

Commit 1b80794

Browse files
zxd1997066 and liangan1 authored
[Intel XPU] Add common code for UTs on XPU (#3335)
* Add common code for XPU
* Fix format issue
* Remove case
* Refine the XPU skip function
* Change auto_device_check to get_current_accelerator_device

Co-authored-by: Zhang, Liangang <[email protected]>
1 parent e110227 commit 1b80794

File tree

3 files changed

+81
-1
lines changed

3 files changed

+81
-1
lines changed

.github/scripts/ci_test_xpu.sh

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,14 @@ python3 -c "import torch; import torchao; print(f'Torch version: {torch.__versio
1414

1515
pip install pytest expecttest parameterized accelerate hf_transfer 'modelscope!=1.15.0'
1616

17-
pytest -v -s torchao/test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
17+
pytest -v -s torchao/test/quantization/
18+
19+
pytest -v -s torchao/test/dtypes/
20+
21+
pytest -v -s torchao/test/float8/
22+
23+
pytest -v -s torchao/test/integration/test_integration.py
24+
25+
pytest -v -s torchao/test/prototype/
26+
27+
pytest -v -s torchao/test/test_ao_models.py

torchao/testing/utils.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,69 @@ def wrapper(*args, **kwargs):
9898
return decorator
9999

100100

101+
def skip_if_no_xpu():
102+
try:
103+
import pytest
104+
105+
has_pytest = True
106+
except ImportError:
107+
has_pytest = False
108+
import unittest
109+
110+
def decorator(func):
111+
@functools.wraps(func)
112+
def wrapper(*args, **kwargs):
113+
if not torch.xpu.is_available():
114+
skip_message = "No XPU available"
115+
if has_pytest:
116+
pytest.skip(skip_message)
117+
else:
118+
unittest.skip(skip_message)
119+
return func(*args, **kwargs)
120+
121+
return wrapper
122+
123+
return decorator
124+
125+
126+
def skip_if_xpu(message=None):
127+
"""
128+
Decorator to skip tests on XPU platform with custom message.
129+
130+
Args:
131+
message (str, optional): Additional information about why the test is skipped.
132+
"""
133+
try:
134+
import pytest
135+
136+
has_pytest = True
137+
except ImportError:
138+
has_pytest = False
139+
import unittest
140+
141+
def decorator(func):
142+
@functools.wraps(func)
143+
def wrapper(*args, **kwargs):
144+
if torch.xpu.is_available():
145+
skip_message = "Skipping the test in XPU"
146+
if message:
147+
skip_message += f": {message}"
148+
if has_pytest:
149+
pytest.skip(skip_message)
150+
else:
151+
unittest.skip(skip_message)
152+
return func(*args, **kwargs)
153+
154+
return wrapper
155+
156+
# Handle both @skip_if_xpu and @skip_if_xpu() syntax
157+
if callable(message):
158+
func = message
159+
message = None
160+
return decorator(func)
161+
return decorator
162+
163+
101164
def skip_if_no_cuda():
102165
import unittest
103166

torchao/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,13 @@ def get_available_devices():
137137
return devices
138138

139139

140+
def get_current_accelerator_device():
141+
if torch.accelerator.is_available():
142+
return torch.accelerator.current_accelerator()
143+
else:
144+
return None
145+
146+
140147
def get_compute_capability():
141148
if torch.cuda.is_available():
142149
capability = torch.cuda.get_device_capability()

0 commit comments

Comments
 (0)