[Benchmark] gather_gemv kernel and test #635
Conversation
stack-info: PR: #635, branch: Sibylau/stack/3
Force-pushed from f0765bb to d64b898.
yf225 left a comment:
Thanks @Sibylau! Left some nit comments, and you might also need to rebase to fix conflicts with the main branch.
examples/gather_gemv.py
(Outdated)

    def baseline_gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
        """PyTorch baseline implementation."""
        # A hard-wired fix for tritonbench baseline: w[idx].to(x.dtype) @ x
maybe can remove this comment
examples/gather_gemv.py
(Outdated)

        for idx_val in idx.tolist():
            outputs.append(w[idx_val].to(x.dtype) @ x)
        return torch.stack(outputs, dim=0)
        # return torch.stack([w[idx[0]].to(x.dtype) @ x, w[idx[1]].to(x.dtype) @ x])
maybe remove?
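For context on what the baseline computes: gather_gemv selects rows (or row blocks) of `w` by index and multiplies each with `x`. A minimal NumPy sketch (illustrative only; the shapes, seed, and variable names here are assumptions, not taken from the PR) shows that the loop form in the snippet above is equivalent to a single fancy-indexing gather:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.standard_normal((8, 4, 4)).astype(np.float16)  # hypothetical stacked weight matrices
idx = np.array([3, 5])                                 # gather indices
x = rng.standard_normal(4).astype(np.float32)

# Loop form, mirroring the baseline in the PR: gather one matrix per
# index, cast to x's dtype, multiply, then stack the results.
outputs = [w[i].astype(x.dtype) @ x for i in idx.tolist()]
looped = np.stack(outputs, axis=0)

# Vectorized form: one gather, one batched matmul.
vectorized = w[idx].astype(x.dtype) @ x

assert np.allclose(looped, vectorized)
```

This is why the commented-out one-liner and the loop produce the same result; the loop just avoids hard-coding the number of indices.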
test/test_examples.py
(Outdated)

        args,
        expected(*args),
        fn_name="gather_gemv",
        block_sizes=[64, 64],
Not exactly sure why the AMD job fails, but I suspect changing the block sizes to some smaller value might help.
@yf225 The CI test on ROCm fails due to a code mismatch. Do you know why the generated code for AMD is different? Can I put @skipIfRocm on this kernel test?
test/test_examples.py
(Outdated)

        )
    )

    @skipIfRocm("failure on rocm")
Instead of using skipIfRocm, which skips the whole test including the output equality check, maybe we can add a skip_rocm: bool arg to def assertExpectedJournal that skips the journal check if the device is ROCm. And to check the device is ROCm, we can add something like this to _testing.py:

    def is_rocm() -> bool:
        """Return True if running on ROCm (AMD GPU)."""
        return (
            triton.runtime.driver.active.get_current_target().backend == "hip"
            and DEVICE.type == "cuda"
        )

(Please feel free to do this in a follow-up PR. Thanks!)
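To make the suggested shape concrete, here is a hypothetical sketch of that pattern (the function names and signature are assumptions for illustration, not the repo's actual helpers; the real `is_rocm` would query Triton's driver as shown above). Only the generated-code comparison is gated; any numerical checks the caller runs are unaffected:

```python
def is_rocm() -> bool:
    """Stubbed backend probe; the real version would query Triton's active driver."""
    return False  # pretend we are on a non-ROCm device for this sketch

def assert_expected_journal(actual: str, expected: str, skip_rocm: bool = False) -> None:
    """Compare generated code against the journal, unless skipped on ROCm."""
    if skip_rocm and is_rocm():
        return  # journal check skipped; output equality checks still ran elsewhere
    assert actual == expected, "generated code does not match journal"

# Usage: the journal comparison is conditional, so a backend with
# different codegen can opt out without skipping the whole test.
assert_expected_journal("kernel_v1", "kernel_v1", skip_rocm=True)
```

The design point is that skipping happens inside the journal assertion rather than at the test-decorator level, so ROCm still exercises the kernel and verifies correctness.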
stack-info: PR: #635, branch: Sibylau/stack/3
Force-pushed from 78a4c22 to c8421c3.
Stacked PRs:
[Benchmark] gather_gemv kernel and test