Skip to content

Commit 1716c26

Browse files
authored
Generalize test_print and test_tensor_descriptor to support different accelerators (#801)
1 parent ca832fa commit 1716c26

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

test/test_print.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,8 @@ def run_kernel_and_capture_output(self, kernel_fn, args):
3838
code, result = code_and_output(kernel_fn, args)
3939

4040
# Wait for any device prints to reach the host
41-
if hasattr(result, "device") and result.device.type == "cuda":
42-
torch.cuda.synchronize()
41+
if hasattr(result, "device") and result.device.type == DEVICE.type:
42+
torch.accelerator.synchronize()
4343

4444
# Grab what pytest captured: stdout + stderr
4545
out, err = self._capfd.readouterr()
@@ -69,8 +69,8 @@ def run_kernel_and_capture_output(self, kernel_fn, args):
6969
code, result = code_and_output(kernel_fn, args)
7070

7171
# Force GPU synchronization to ensure all device prints complete
72-
if hasattr(result, "device") and result.device.type == "cuda":
73-
torch.cuda.synchronize()
72+
if hasattr(result, "device") and result.device.type == DEVICE.type:
73+
torch.accelerator.synchronize()
7474

7575
# Ensure all output is flushed
7676
sys.stdout.flush()

test/test_tensor_descriptor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def matmul(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
240240
y = torch.randn((64, 64), device=DEVICE, dtype=torch.float16)
241241

242242
code, result = code_and_output(matmul, (x, y))
243-
torch.cuda.synchronize()
243+
torch.accelerator.synchronize()
244244
expected = torch.matmul(x, y)
245245
torch.testing.assert_close(result, expected, atol=1e-2, rtol=1e-2)
246246

0 commit comments

Comments (0)