from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(ignore_warnings=[helion.exc.TensorOperationInWrapper])
def gather_gemv(w: torch.Tensor, idx: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
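    """
    Gather the weight matrices selected by idx and multiply each by the vector x.

    w has shape [batch_size, s, s], idx has shape [num_indices] (negative
    indices wrap around), and x has shape [s]. The result has shape
    [num_indices, s] with out[i] = w[idx[i]] @ x, accumulated in float32 and
    cast back to x.dtype.
    """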
    batch_size, s, s2 = w.size()
    num_indices = idx.size(0)
    assert s == s2, f"size mismatch {s} != {s2}"
    assert x.size(0) == s, f"vector size mismatch {x.size(0)} != {s}"

    out = torch.empty([num_indices, s], dtype=x.dtype, device=w.device)

    # Handle negative indices by wrapping around
    idx_wrapped = torch.where(idx < 0, idx + batch_size, idx)

    for tile_i, tile_j in hl.tile([num_indices, s]):
        acc = hl.zeros([tile_i, tile_j], dtype=torch.float32)
        for tile_k in hl.tile(s):
            # Get the indices for this tile
            indices_tile = idx_wrapped[tile_i]
            # Gather matrix elements: w[indices_tile, tile_j, tile_k]
            # gives us shape [tile_i, tile_j, tile_k]; we then sum over
            # the k dimension.
            w_tile = w[indices_tile, tile_j, tile_k]
            x_tile = x[tile_k]
            # Multiply and accumulate:
            # w_tile has shape [tile_i, tile_j, tile_k]
            # x_tile has shape [tile_k]
            # We want out[tile_i, tile_j] += sum_k(w_tile[:, :, k] * x_tile[k])
            acc = acc + (w_tile.to(torch.float32) * x_tile.to(torch.float32)).sum(dim=-1)
        out[tile_i, tile_j] = acc.to(x.dtype)

    return out


def gather_gemv_tritonbench(
    w: torch.Tensor, idx: torch.Tensor, x: torch.Tensor
) -> torch.Tensor:
    """Wrapper for tritonbench that matches its interface."""
    return gather_gemv(w, idx, x)


def gather_gemv_ref(w: torch.Tensor, idx: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Reference implementation for testing."""
    return w[idx].to(x.dtype) @ x


def main() -> None:
    s = 2048
    # Create an int8 weight tensor by generating floats and converting
    w_float = torch.randn([8, s, s], device="cuda", dtype=torch.float32)
    w = (w_float * 127).to(torch.int8)
    idx = torch.randint(0, 8, [2], device="cuda", dtype=torch.int64)
    x = torch.randn([s], device="cuda", dtype=torch.bfloat16)

    run_example(gather_gemv, gather_gemv_ref, (w, idx, x), atol=1e-2, rtol=1e-2)


if __name__ == "__main__":
    main()