     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
+devices = []
+if torch.cuda.is_available():
+    devices.append("cuda")
+
+if torch.xpu.is_available():
+    devices.append("xpu")
+
+
 # source: https://stackoverflow.com/a/22638709
 @pytest.fixture(autouse=True)
 def run_around_tests():
@@ -63,37 +71,50 @@ def cuda_kernel_profiler(kernel_pattern):
         result["found"] = any(kernel_pattern in name for name in kernel_names)
 
 
-@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    not (torch.cuda.is_available() or torch.xpu.is_available()),
+    reason="CUDA or XPU not available",
+)
 @pytest.mark.skipif(
     not torch_version_at_least("2.8.0"), reason="torch.compile requires PyTorch 2.8+"
 )
 @pytest.mark.parametrize("elem_dtype", [torch.float8_e4m3fn, torch.float4_e2m1fn_x2])
 @pytest.mark.parametrize("bias", [True, False])
 @pytest.mark.parametrize("compile", [True, False])
-@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize(
+    "emulate", [True, False] if (not torch.xpu.is_available()) else [True]
+)
+@pytest.mark.parametrize("device", devices)
 @torch.no_grad()
 @skip_if_rocm(
     "ROCm float4 gemm require gfx950"
 )  # TODO(future): deploy gfx950 in ROCM CI
-def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: bool):
+def test_inference_workflow_mx(
+    elem_dtype, bias: bool, compile: bool, emulate: bool, device
+):
     """
     Smoke test for inference compile
     """
     # TODO(future): figure out why these CUDA capability conditions are not properly
     # applied when inside `pytest.mark.skipif` for this test
-    if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
+    if (
+        elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2)
+    ) and torch.cuda.is_available():
         if not is_sm_at_least_89():
             pytest.skip("CUDA capability >= 8.9 required for float8 in triton")
         elif not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp8 gemm")
-    elif elem_dtype == torch.float4_e2m1fn_x2:
+    elif (elem_dtype == torch.float4_e2m1fn_x2) and torch.cuda.is_available():
         if not is_sm_at_least_100() and not emulate:
             pytest.skip("CUDA capability >= 10.0 required for mxfp4 gemm")
         elif compile:
             # TODO(future PR): investigate and fix this
-            pytest.skip("mxfp4 + compile currently does not work, low SQNR")
+            pytest.skip("mxfp4 + compile currently does not work on CUDA, low SQNR")
+
+    if (elem_dtype == torch.float4_e2m1fn_x2) and torch.xpu.is_available() and compile:
+        pytest.skip("mxfp4 + compile currently does not work on XPU, low SQNR")
 
-    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device="cuda")
+    m = nn.Linear(32, 128, bias=bias, dtype=torch.bfloat16, device=device)
     m_mx = copy.deepcopy(m)
 
     if emulate:
@@ -111,7 +132,7 @@ def test_inference_workflow_mx(elem_dtype, bias: bool, compile: bool, emulate: b
     if compile:
         m_mx = torch.compile(m_mx, fullgraph=True)
 
-    x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
+    x = torch.randn(128, 32, device=device, dtype=torch.bfloat16)
     y_ref = m(x)
     y_mx = m_mx(x)
     sqnr = compute_error(y_ref, y_mx)