
Commit 7d09e0a

[Benchmark] gather_gemv kernel and test (#635)
1 parent eff83b1 commit 7d09e0a

4 files changed (+247, -0 lines)


benchmarks/run.py

Lines changed: 13 additions & 0 deletions
@@ -200,6 +200,11 @@ class RunResult:
         "examples.welford",
         "welford",
     ),
+    "gather_gemv": (
+        "tritonbench.operators.gather_gemv.operator",
+        "examples.gather_gemv",
+        "gather_gemv_tritonbench",
+    ),
     "int4_gemm": (
         "tritonbench.operators.int4_gemm.int4_gemm",
         "examples.int4_gemm",
@@ -305,6 +310,14 @@ class RunResult:
         "helion_kl_div_tritonbench-speedup": "helion_speedup",
         "helion_kl_div_tritonbench-accuracy": "helion_accuracy",
     },
+    "gather_gemv": {
+        "test_0-speedup": "triton_speedup",
+        "test_0-accuracy": "triton_accuracy",
+        "test_inductor-speedup": "torch_compile_speedup",
+        "test_inductor-accuracy": "torch_compile_accuracy",
+        "helion_gather_gemv_tritonbench-speedup": "helion_speedup",
+        "helion_gather_gemv_tritonbench-accuracy": "helion_accuracy",
+    },
     "int4_gemm": {
         "triton_int4_gemm-speedup": "triton_speedup",
         "triton_int4_gemm-accuracy": "triton_accuracy",

examples/gather_gemv.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
+"""
+Helion Gather GEMV Kernel Example
+=================================
+This example demonstrates a Helion kernel implementation of a gather operation
+followed by general matrix-vector multiplication (GEMV). The operation is:
+w[idx].to(x.dtype) @ x, where w is a 3D tensor, idx contains indices to gather,
+and x is a vector.
+
+Based on the tritonbench gather_gemv operator that is motivated by Mixtral performance
+where gather + gemv is the primary kernel.
+"""
+
+# %%
+# Imports
+# -------
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import torch
+from torch import Tensor
+
+import helion
+from helion._testing import run_example
+import helion.language as hl
+
+if TYPE_CHECKING:
+    from collections.abc import Callable
+
+
+# %%
+# Gather GEMV Kernel
+# ------------------
+@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
+def gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
+    """
+    Performs a gather operation on w using idx, then matrix-vector multiplication with x.
+
+    Args:
+        w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
+        idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
+        x (Tensor): Vector of shape [S] to multiply with the gathered matrices.
+
+    Returns:
+        Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
+    """
+    B, S1, S2 = w.size()
+    N = idx.size(0)
+    S = x.size(0)
+    assert S1 == S2, f"Weight matrix must be square, got {S1} != {S2}"
+    assert S == S1, f"Vector size {S} must match matrix size {S1}"
+
+    # Rearrange shapes for matrix-vector multiplication
+    w_view = w.contiguous().view(B * S, S).to(x.dtype)  # Shape: [B * S, S]
+    x = x.view(S, 1)
+
+    # Create output tensor
+    out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)
+
+    # Perform matrix-vector multiplication for each gathered matrix
+    for tile_n_s in hl.tile(N * S):
+        acc = hl.zeros([tile_n_s, 1], dtype=torch.float32)
+        idx_id = tile_n_s.index // S
+        idx_gather = idx[idx_id]
+        for tile_k in hl.tile(S):
+            # Matrix-vector multiplication
+            gathered = w_view[idx_gather * S + tile_n_s.index % S, tile_k]
+            acc += hl.dot(gathered, x[tile_k, :])
+        out[tile_n_s, :] = acc
+
+    return out.contiguous().view(N, S)
+
+
+# %%
+# Verification Function
+# ---------------------
+def check(B: int, S: int, N: int) -> None:
+    """
+    Verify the gather_gemv kernel implementation against PyTorch's baseline.
+
+    Args:
+        B (int): Batch size for weight matrix.
+        S (int): Sequence length (matrix size).
+        N (int): Number of indices to gather.
+    """
+    # Create test tensors matching tritonbench format
+    w = torch.randn((B, S, S), device="cuda:0", dtype=torch.float16)
+    idx = torch.randint(0, B, [N], device="cuda:0", dtype=torch.int32)
+    x = torch.randn((S), device="cuda:0", dtype=torch.float16)
+
+    def baseline_gather_gemv(w: Tensor, idx: Tensor, x: Tensor) -> Tensor:
+        """PyTorch baseline implementation."""
+        outputs = []
+        for idx_val in idx.tolist():
+            outputs.append(w[idx_val].to(x.dtype) @ x)
+        return torch.stack(outputs, dim=0)
+
+    run_example(gather_gemv, baseline_gather_gemv, (w, idx, x))
+
+
+# %%
+# Tritonbench Integration
+# -----------------------
+def gather_gemv_tritonbench(
+    tb_op: object, w: Tensor, idx: Tensor, x: Tensor
+) -> Callable:
+    """
+    Wrapper for tritonbench that matches its interface.
+
+    Args:
+        w (Tensor): Weight matrix of shape [B, S, S].
+        idx (Tensor): Index tensor of shape [N].
+        x (Tensor): Vector of shape [S].
+
+    Returns:
+        Callable: A callable that runs the gather_gemv kernel.
+    """
+    return lambda: gather_gemv(w, idx, x)
+
+
+# %%
+# Main Function
+# -------------
+def main() -> None:
+    """
+    Main entry point that runs the gather_gemv kernel verification.
+    Uses sizes similar to tritonbench for consistency.
+    """
+    # Test with sizes from tritonbench
+    B = 8  # Batch size, could be number of experts in MoE
+    N = 2  # Number of indices, experts selected
+    for i in range(11, 15):
+        S = 2**i
+        print(f"Testing with B={B}, S={S}, N={N}")
+        check(B, S, N)
+
+
+# %%
+if __name__ == "__main__":
+    main()
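
Outside the bundled check() helper, the example can be smoke-tested against the same eager expression used by the unit test below. A minimal sketch; the sizes, device string, import path, and tolerances are illustrative assumptions, not part of this commit:

    import torch
    from examples.gather_gemv import gather_gemv  # assumes the repo root is on sys.path

    B, S, N = 8, 2048, 2  # one of the sizes swept by main()
    w = torch.randn(B, S, S, device="cuda", dtype=torch.float16)
    idx = torch.randint(0, B, (N,), device="cuda", dtype=torch.int32)
    x = torch.randn(S, device="cuda", dtype=torch.float16)

    out = gather_gemv(w, idx, x)   # Helion kernel: gather rows of w, then GEMV
    ref = w[idx].to(x.dtype) @ x   # eager PyTorch reference
    torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2)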

test/test_examples.expected

Lines changed: 70 additions & 0 deletions
@@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
     _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_gather_gemv)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from torch._inductor.runtime.triton_compat import libdevice
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion_gather_gemv(out, idx, w_view, x, out_size_0, idx_stride_0, out_stride_0, w_view_stride_0, w_view_stride_1, x_stride_0, S1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
+    pid_0 = tl.program_id(0)
+    offset_0 = pid_0 * _BLOCK_SIZE_0
+    indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
+    mask_0 = indices_0 < out_size_0
+    acc = tl.full([_BLOCK_SIZE_0, 1], 0.0, tl.float32)
+    v_0 = tl.cast(S1, tl.int32)
+    v_1 = tl.where((indices_0 < 0) != (v_0 < 0), tl.where(indices_0 % v_0 != 0, indices_0 // v_0 - 1, indices_0 // v_0), indices_0 // v_0)
+    idx_gather = tl.load(idx + v_1 * idx_stride_0, mask_0, other=0)
+    for offset_1 in tl.range(0, S1.to(tl.int32), _BLOCK_SIZE_1):
+        indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
+        mask_1 = indices_1 < S1
+        idx_gather_copy = idx_gather
+        acc_copy = acc
+        idx_gather_copy_0 = idx_gather_copy
+        acc_copy_0 = acc_copy
+        v_2 = tl.cast(S1, tl.int32)
+        v_3 = idx_gather_copy_0 * v_2
+        v_4 = tl.cast(S1, tl.int32)
+        v_5 = indices_0 % v_4
+        v_6 = tl.full([], 0, tl.int32)
+        v_7 = v_5 != v_6
+        v_8 = libdevice.signbit(v_5) != 0 if v_5.dtype is tl.float32 else v_5 < 0
+        v_9 = libdevice.signbit(v_4) != 0 if v_4.dtype is tl.float32 else v_4 < 0
+        v_10 = v_8 != v_9
+        v_11 = v_7 & v_10
+        v_12 = v_5 + v_4
+        v_13 = tl.where(v_11, v_12, v_5)
+        v_14 = v_3 + v_13
+        gathered = tl.load(w_view + (v_14[:, None] * w_view_stride_0 + indices_1[None, :] * w_view_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
+        load_1 = tl.load(x + indices_1[:, None] * x_stride_0, mask_1[:, None], other=0)
+        dot = tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.dot(tl.cast(gathered, tl.float32), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]))), [0, 2, 1]), [16, 16]), input_precision='tf32', out_dtype=tl.float32), [16, 2, 8]), [0, 2, 1]))[0], [16, 2, 4]), [0, 2, 1]))[0], [16, 2, 2]), [0, 2, 1]))[0], [16, 2, 1]), [0, 2, 1]))[0]
+        acc = acc_copy_0 + dot
+    tl.store(out + indices_0[:, None] * out_stride_0, acc, mask_0[:, None])
+
+def gather_gemv(w: Tensor, idx: Tensor, x: Tensor, *, _launcher=_default_launcher):
+    """
+    Performs a gather operation on w using idx, then matrix-vector multiplication with x.
+
+    Args:
+        w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
+        idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
+        x (Tensor): Vector of shape [S] to multiply with the gathered matrices.
+
+    Returns:
+        Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
+    """
+    B, S1, S2 = w.size()
+    N = idx.size(0)
+    S = x.size(0)
+    assert S1 == S2, f'Weight matrix must be square, got {S1} != {S2}'
+    assert S == S1, f'Vector size {S} must match matrix size {S1}'
+    w_view = w.contiguous().view(B * S, S).to(x.dtype)
+    x = x.view(S, 1)
+    out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)
+    _BLOCK_SIZE_0 = 16
+    _BLOCK_SIZE_1 = 16
+    _launcher(_helion_gather_gemv, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, idx, w_view, x, out.size(0), idx.stride(0), out.stride(0), w_view.stride(0), w_view.stride(1), x.stride(0), S1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=8, num_stages=1)
+    return out.contiguous().view(N, S)
+
 --- assertExpectedJournal(TestExamples.test_geglu)
 from __future__ import annotations
 
test/test_examples.py

Lines changed: 24 additions & 0 deletions
@@ -12,6 +12,7 @@
 from helion._testing import TestCase
 from helion._testing import check_example
 from helion._testing import import_path
+from helion._testing import is_cuda
 from helion._testing import skipIfRefEager
 from helion._testing import skipIfRocm

@@ -1184,6 +1185,29 @@ def test_kl_div(self):
             )
         )
 
+    def test_gather_gemv(self):
+        args = (
+            torch.randn([8, 1024, 1024], device=DEVICE, dtype=torch.float32),
+            torch.randint(0, 8, [2], device=DEVICE, dtype=torch.int32),
+            torch.randn([1024], device=DEVICE, dtype=torch.float32),
+        )
+
+        def expected(w, idx, x):
+            return w[idx].to(x.dtype) @ x
+
+        code = check_example(
+            "gather_gemv",
+            args,
+            expected(*args),
+            fn_name="gather_gemv",
+            block_sizes=[16, 16],
+            num_warps=8,
+            num_stages=1,
+        )
+
+        if is_cuda():
+            self.assertExpectedJournal(code)
+
     def test_int4_gemm(self):
         # Matrix dimensions
         M, K, N = 256, 512, 256
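
The new test can also be exercised on its own through the standard unittest loader; a sketch in which the dotted import path and CUDA availability are assumptions:

    import unittest
    import test.test_examples as test_examples  # assumes the repo root is on sys.path

    # Load and run only the new gather_gemv test case.
    suite = unittest.TestLoader().loadTestsFromName(
        "TestExamples.test_gather_gemv", test_examples
    )
    unittest.TextTestRunner(verbosity=2).run(suite)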
