
Commit 5f1d73a

[WIP] gather_gemv
1 parent c6c41bd commit 5f1d73a

2 files changed, 71 insertions(+), 0 deletions(-)

benchmarks/run.py

Lines changed: 5 additions & 0 deletions
@@ -99,6 +99,11 @@
         "examples.welford",
         "welford_tritonbench",
     ),
+    "gather_gemv": (
+        "tritonbench.operators.gather_gemv.operator",
+        "examples.gather_gemv",
+        "gather_gemv_tritonbench",
+    ),
 }
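The new mapping entry ties the benchmark name "gather_gemv" to three pieces: the tritonbench operator module, the Helion example module, and the wrapper function exported from it. A minimal sketch of how such a three-field entry could be resolved into a callable is shown below; the KERNEL_MAPPINGS name and the resolve_kernel() helper are assumptions for illustration, not the actual run.py implementation.

import importlib

# Hypothetical registry mirroring the entry added above:
# (tritonbench operator module, Helion example module, wrapper function name).
KERNEL_MAPPINGS = {
    "gather_gemv": (
        "tritonbench.operators.gather_gemv.operator",
        "examples.gather_gemv",
        "gather_gemv_tritonbench",
    ),
}

def resolve_kernel(name: str):
    # Import the Helion example module and return the wrapper callable.
    _operator_module, example_module, wrapper_name = KERNEL_MAPPINGS[name]
    return getattr(importlib.import_module(example_module), wrapper_name)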

examples/gather_gemv.py

Lines changed: 66 additions & 0 deletions
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def gather_gemv(w: torch.Tensor, idx: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    batch_size, s, s2 = w.size()
    num_indices = idx.size(0)
    assert s == s2, f"size mismatch {s} != {s2}"
    assert x.size(0) == s, f"vector size mismatch {x.size(0)} != {s}"

    out = torch.empty([num_indices, s], dtype=x.dtype, device=w.device)

    # Handle negative indices by wrapping around
    idx_wrapped = torch.where(idx < 0, idx + batch_size, idx)

    for tile_i, tile_j in hl.tile([num_indices, s]):
        acc = hl.zeros([tile_i, tile_j], dtype=torch.float32)
        for tile_k in hl.tile(s):
            # Get the indices for this tile
            indices_tile = idx_wrapped[tile_i]
            # Gather matrix elements: w[indices_tile, tile_j, tile_k]
            # gives shape [tile_i, tile_j, tile_k]; we sum over the k dimension
            w_tile = w[indices_tile, tile_j, tile_k]
            x_tile = x[tile_k]
            # Multiply and accumulate:
            # out[tile_i, tile_j] += sum_k(w_tile[:, :, k] * x_tile[k])
            acc = acc + (w_tile.to(torch.float32) * x_tile.to(torch.float32)).sum(dim=-1)
        out[tile_i, tile_j] = acc.to(x.dtype)

    return out


def gather_gemv_tritonbench(
    w: torch.Tensor, idx: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """Wrapper for tritonbench that matches its interface."""
    return gather_gemv(w, idx, x)


def gather_gemv_ref(w: torch.Tensor, idx: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reference implementation for testing."""
    return w[idx].to(x.dtype) @ x


def main() -> None:
    s = 2048
    # Create an int8 weight tensor by generating floats and converting
    w_float = torch.randn([8, s, s], device="cuda", dtype=torch.float32)
    w = (w_float * 127).to(torch.int8)
    idx = torch.randint(0, 8, [2], device="cuda", dtype=torch.int64)
    x = torch.randn([s], device="cuda", dtype=torch.bfloat16)

    run_example(gather_gemv, gather_gemv_ref, (w, idx, x), atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()
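Aside from the run_example check above, the gather-then-GEMV semantics can be verified against an explicit per-index loop in plain PyTorch. This is only an illustrative sketch, not part of the commit; the small CPU shapes and the loop form are assumptions chosen to keep the check self-contained.

import torch

# Small shapes for a quick check; main() above uses s=2048 on CUDA.
s, batch, num_indices = 64, 8, 2
w = torch.randint(-127, 127, (batch, s, s), dtype=torch.int8)
idx = torch.randint(0, batch, (num_indices,))
x = torch.randn(s, dtype=torch.bfloat16)

# Loop form: pick each selected matrix and do a matrix-vector product.
out_loop = torch.stack([w[i].to(x.dtype) @ x for i in idx.tolist()])

# Batched form, identical to gather_gemv_ref: gather the matrices, then GEMV.
out_ref = w[idx].to(x.dtype) @ x

torch.testing.assert_close(out_loop, out_ref)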
