
Commit 6883039

Support mx_tensor and enable its test on Intel GPU
1 parent f303f4c commit 6883039

3 files changed, +173 -75 lines changed

test/prototype/mx_formats/test_inference_workflow.py

Lines changed: 29 additions & 8 deletions
@@ -36,6 +36,14 @@
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
@@ -63,37 +71,50 @@ def cuda_kernel_profiler(kernel_pattern):
     result["found"] = any(kernel_pattern in name for name in kernel_names)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize(
+    "emulate", [True, False] if (not torch.xpu.is_available()) else [True]
+)
+@pytest.mark.parametrize("device", devices)
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
+def test_inference_workflow_mx(
+    elem_dtype, bias: bool, compile: bool, emulate: bool, device
+):
     """
     Smoke test for inference compile
     """
     # TODO(future): figure out why these CUDA capability conditions are not properly
     # applied when inside `pytest.mark.skipif` for this test
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    ) and torch.cuda.is_available():
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
         elif not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
-    elif elem_dtype == torch.float4_e2m1fn_x2:
+    elif (elem_dtype == torch.float4_e2m1fn_x2) and torch.cuda.is_available():
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
         elif compile:
             # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work on CUDA, low SQNR")
+
+    if (elem_dtype == torch.float4_e2m1fn_x2) and torch.xpu.is_available() and compile:
+        pytest.skip("mxfp4 + compile currently does not work on XPU, low SQNR")
 
-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device=device)
     m_mx = copy.deepcopy(m)
 
     if emulate:
@@ -111,7 +132,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
 
-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)
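The core of the change is the module-level device-discovery list plus the extra @pytest.mark.parametrize("device", devices) axis, so the same test body runs on whichever accelerator backend is present. Below is a minimal, self-contained sketch of that pattern; the test name and tensor shapes are illustrative only and are not code from this commit.

# Sketch of the device-agnostic pytest pattern used above (illustrative only).
import pytest
import torch

# Collect whichever accelerator backends are present at collection time.
devices = []
if torch.cuda.is_available():
    devices.append("cuda")
if torch.xpu.is_available():
    devices.append("xpu")


@pytest.mark.skipif(not devices, reason="CUDA or XPU not available")
@pytest.mark.parametrize("device", devices)
def test_linear_smoke(device):
    # Tensors are created on the parametrized device instead of a hard-coded
    # "cuda" string, which is what the diff changes in the real test.
    m = torch.nn.Linear(32, 128, dtype=torch.bfloat16, device=device)
    x = torch.randn(128, 32, dtype=torch.bfloat16, device=device)
    assert m(x).shape == (128, 128)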
