@@ -841,6 +841,76 @@ def fp8_gemm(x: torch.Tensor, y: torch.Tensor, *, _launcher=_default_launcher):
841841 _launcher(_helion_fp8_gemm, (triton.cdiv(256, _BLOCK_SIZE_0) * triton.cdiv(256, _BLOCK_SIZE_1),), x, y, out, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
842842 return out
843843
844+ --- assertExpectedJournal(TestExamples.test_gather_gemv)
845+ from __future__ import annotations
846+
847+ import torch
848+ import triton
849+ import triton.language as tl
850+ from torch._inductor.runtime.triton_compat import libdevice
851+ from helion.runtime import default_launcher as _default_launcher
852+
853+ @triton.jit
854+ def _helion_gather_gemv(out, idx, w_view, x, out_size_0, idx_stride_0, out_stride_0, w_view_stride_0, w_view_stride_1, x_stride_0, S1, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr):
855+ pid_0 = tl.program_id(0)
856+ offset_0 = pid_0 * _BLOCK_SIZE_0
857+ indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
858+ mask_0 = indices_0 < out_size_0
859+ acc = tl.full([_BLOCK_SIZE_0, 1], 0.0, tl.float32)
860+ v_0 = tl.cast(S1, tl.int32)
861+ v_1 = tl.where((indices_0 < 0) != (v_0 < 0), tl.where(indices_0 % v_0 != 0, indices_0 // v_0 - 1, indices_0 // v_0), indices_0 // v_0)
862+ idx_gather = tl.load(idx + v_1 * idx_stride_0, mask_0, other=0)
863+ for offset_1 in tl.range(0, S1.to(tl.int32), _BLOCK_SIZE_1):
864+ indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
865+ mask_1 = indices_1 < S1
866+ idx_gather_copy = idx_gather
867+ acc_copy = acc
868+ idx_gather_copy_0 = idx_gather_copy
869+ acc_copy_0 = acc_copy
870+ v_2 = tl.cast(S1, tl.int32)
871+ v_3 = idx_gather_copy_0 * v_2
872+ v_4 = tl.cast(S1, tl.int32)
873+ v_5 = indices_0 % v_4
874+ v_6 = tl.full([], 0, tl.int32)
875+ v_7 = v_5 != v_6
876+ v_8 = libdevice.signbit(v_5) != 0 if v_5.dtype is tl.float32 else v_5 < 0
877+ v_9 = libdevice.signbit(v_4) != 0 if v_4.dtype is tl.float32 else v_4 < 0
878+ v_10 = v_8 != v_9
879+ v_11 = v_7 & v_10
880+ v_12 = v_5 + v_4
881+ v_13 = tl.where(v_11, v_12, v_5)
882+ v_14 = v_3 + v_13
883+ gathered = tl.load(w_view + (v_14[:, None] * w_view_stride_0 + indices_1[None, :] * w_view_stride_1), mask_0[:, None] & mask_1[None, :], other=0)
884+ load_1 = tl.load(x + indices_1[:, None] * x_stride_0, mask_1[:, None], other=0)
885+ dot = tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.split(tl.permute(tl.reshape(tl.dot(tl.cast(gathered, tl.float32), tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]), tl.zeros_like(tl.reshape(tl.permute(tl.join(tl.cast(load_1, tl.float32), tl.zeros_like(tl.cast(load_1, tl.float32))), [0, 2, 1]), [16, 2]))), [0, 2, 1]), [16, 4]))), [0, 2, 1]), [16, 8]))), [0, 2, 1]), [16, 16]), input_precision='tf32', out_dtype=tl.float32), [16, 2, 8]), [0, 2, 1]))[0], [16, 2, 4]), [0, 2, 1]))[0], [16, 2, 2]), [0, 2, 1]))[0], [16, 2, 1]), [0, 2, 1]))[0]
886+ acc = acc_copy_0 + dot
887+ tl.store(out + indices_0[:, None] * out_stride_0, acc, mask_0[:, None])
888+
889+ def gather_gemv(w: Tensor, idx: Tensor, x: Tensor, *, _launcher=_default_launcher):
890+ """
891+ Performs a gather operation on w using idx, then matrix-vector multiplication with x.
892+
893+ Args:
894+ w (Tensor): Weight matrix of shape [B, S, S] where B is batch size, S is sequence length.
895+ idx (Tensor): Index tensor of shape [N] containing indices to gather from dimension 0 of w.
896+ x (Tensor): Vector of shape [S] to multiply with the gathered matrices.
897+
898+ Returns:
899+ Tensor: Result of shape [N, S] where each row i is w[idx[i]] @ x.
900+ """
901+ B, S1, S2 = w.size()
902+ N = idx.size(0)
903+ S = x.size(0)
904+ assert S1 == S2, f'Weight matrix must be square, got {S1} != {S2}'
905+ assert S == S1, f'Vector size {S} must match matrix size {S1}'
906+ w_view = w.contiguous().view(B * S, S).to(x.dtype)
907+ x = x.view(S, 1)
908+ out = torch.empty([N * S, 1], dtype=x.dtype, device=x.device)
909+ _BLOCK_SIZE_0 = 16
910+ _BLOCK_SIZE_1 = 16
911+ _launcher(_helion_gather_gemv, (triton.cdiv(out.size(0), _BLOCK_SIZE_0),), out, idx, w_view, x, out.size(0), idx.stride(0), out.stride(0), w_view.stride(0), w_view.stride(1), x.stride(0), S1, _BLOCK_SIZE_0, _BLOCK_SIZE_1, num_warps=8, num_stages=1)
912+ return out.contiguous().view(N, S)
913+
844914--- assertExpectedJournal(TestExamples.test_geglu)
845915from __future__ import annotations
846916
0 commit comments