Skip to content

Commit 3eeacdf

Browse files
committed
add skip_rocm arg in assertExpectedJournal
1 parent c8421c3 commit 3eeacdf

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

helion/_testing.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,14 @@ def is_cuda() -> bool:
6868
)
6969

7070

71+
def is_rocm() -> bool:
72+
"""Return True if running on ROCm (AMD GPU)."""
73+
return (
74+
triton.runtime.driver.active.get_current_target().backend == "hip" # pyright: ignore[reportAttributeAccessIssue]
75+
and DEVICE.type == "cuda"
76+
)
77+
78+
7179
@contextlib.contextmanager
7280
def track_run_ref_calls() -> Generator[list[int], None, None]:
7381
"""Context manager that tracks BoundKernel.run_ref calls.
@@ -329,8 +337,8 @@ def tearDown(self) -> None:
329337

330338
# NOTE: We no-op these methods because they commonly check behaviors that are not relevant in ref eager mode.
331339
# Instead, we solely rely on the unit test's `torch.testing.assert_close` and `assertRaises` checks to ensure ref eager mode's correctness.
332-
def assertExpectedJournal(self, value: str) -> None:
333-
if not self._in_ref_eager_mode:
340+
def assertExpectedJournal(self, value: str, skip_rocm: bool = False) -> None:
341+
if not self._in_ref_eager_mode and not (skip_rocm and is_rocm()):
334342
super().assertExpectedJournal(value) # type: ignore[misc]
335343

336344
def assertIn(

test/test_examples.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1112,7 +1112,6 @@ def test_kl_div(self):
11121112
)
11131113
)
11141114

1115-
@skipIfRocm("failure on rocm")
11161115
def test_gather_gemv(self):
11171116
args = (
11181117
torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32),
@@ -1132,7 +1131,8 @@ def expected(w, idx, x):
11321131
block_sizes=[16, 16],
11331132
num_warps=8,
11341133
num_stages=1,
1135-
)
1134+
),
1135+
skip_rocm=True,
11361136
)
11371137

11381138
def test_int4_gemm(self):

0 commit comments

Comments
 (0)