
Commit 5195d38

[WIP] grouped_gemm
stack-info: PR: #305, branch: yf225/stack/21
1 parent ae365c2 · commit 5195d38

File tree

2 files changed: +252 −0 lines


benchmarks/run.py

Lines changed: 5 additions & 0 deletions

@@ -89,6 +89,11 @@
         "examples.ragged_attention",
         "ragged_attention_tritonbench",
     ),
+    "grouped_gemm": (
+        "tritonbench.operators.grouped_gemm.operator",
+        "examples.grouped_gemm",
+        "grouped_gemm_tritonbench",
+    ),
 }
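
For context, each entry in this mapping pairs a tritonbench operator module with the Helion example module and the name of the wrapper function to benchmark. Below is a minimal sketch of how such a tuple could be resolved into a callable; the resolve_kernel helper is a hypothetical illustration, not the actual logic in benchmarks/run.py.

# Hypothetical sketch (not part of this commit): resolving an
# (operator_module, example_module, function_name) entry into a callable.
import importlib


def resolve_kernel(entry: tuple[str, str, str]):
    _operator_module, example_module, func_name = entry
    module = importlib.import_module(example_module)  # e.g. "examples.grouped_gemm"
    return getattr(module, func_name)  # e.g. grouped_gemm_tritonbench


kernel_fn = resolve_kernel(
    (
        "tritonbench.operators.grouped_gemm.operator",
        "examples.grouped_gemm",
        "grouped_gemm_tritonbench",
    )
)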

examples/grouped_gemm.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
"""
Grouped GEMM - Multiple matrix multiplications in a single kernel launch
"""

from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(static_shapes=False)
def grouped_gemm(
    A_concat: torch.Tensor,  # Concatenated A matrices [total_M, K]
    B_concat: torch.Tensor,  # Concatenated B matrices [K, total_N]
    group_sizes_M: torch.Tensor,  # [G] - M size for each GEMM
    group_sizes_N: torch.Tensor,  # [G] - N size for each GEMM
    group_offsets_M: torch.Tensor,  # [G+1] - Starting M offset for each GEMM
    group_offsets_N: torch.Tensor,  # [G+1] - Starting N offset for each GEMM
    max_M_tensor: torch.Tensor,  # Dummy tensor of size max(M)
    max_N_tensor: torch.Tensor,  # Dummy tensor of size max(N)
) -> torch.Tensor:  # [total_M, total_N] - Concatenated output
    """Grouped GEMM kernel using concatenated tensors"""
    G = group_sizes_M.shape[0]
    total_M, K = A_concat.shape
    _, total_N = B_concat.shape
    max_M = max_M_tensor.numel()
    max_N = max_N_tensor.numel()

    # Allocate output tensor
    C_concat = torch.zeros(
        total_M, total_N,
        dtype=torch.promote_types(A_concat.dtype, B_concat.dtype),
        device=A_concat.device
    )

    # Process each GEMM
    for g_idx in hl.grid(G):
        # Get dimensions and offsets for this GEMM
        M = group_sizes_M[g_idx]
        N = group_sizes_N[g_idx]
        M_start = group_offsets_M[g_idx]
        N_start = group_offsets_N[g_idx]

        # Skip empty GEMMs
        valid_gemm = (M > 0) * (N > 0)  # Use multiplication instead of 'and'
        if valid_gemm:
            # Tile over output dimensions
            for tile_m, tile_n in hl.tile([max_M, max_N]):
                # Get tile indices
                m_indices = tile_m.index
                n_indices = tile_n.index

                # Create masks for valid elements
                m_valid = m_indices < M
                n_valid = n_indices < N

                # Calculate global indices
                m_indices_valid = torch.where(m_valid, m_indices, 0)
                n_indices_valid = torch.where(n_valid, n_indices, 0)

                # Global indices in concatenated tensors
                global_m = M_start + m_indices_valid
                global_n = N_start + n_indices_valid

                # Initialize accumulator
                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

                # Accumulate over K dimension
                for tile_k in hl.tile(K):
                    k_indices = tile_k.index

                    # Load tiles from concatenated tensors
                    A_tile = A_concat[global_m, k_indices]
                    B_tile = B_concat[k_indices, global_n]

                    # Accumulate
                    acc = torch.addmm(acc, A_tile, B_tile)

                # Write back to output with masking
                block_m = acc.size(0)
                block_n = acc.size(1)

                # Get existing values
                existing_values = C_concat[global_m, global_n]

                # Create 2D mask for output
                mask_2d = m_valid.view(block_m, 1).expand(block_m, block_n) & n_valid.view(1, block_n).expand(block_m, block_n)

                # Write results only for valid positions
                C_concat[global_m, global_n] = torch.where(
                    mask_2d, acc.to(C_concat.dtype), existing_values
                )

    return C_concat


def grouped_gemm_helion_kernel_args_gen(
    group_A: list[torch.Tensor],
    group_B: list[torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate arguments for the Helion kernel by concatenating inputs"""
    device = group_A[0].device
    dtype = group_A[0].dtype
    G = len(group_A)

    # Check that all matrices have the same K dimension
    K = group_A[0].shape[1]
    for i in range(G):
        assert group_A[i].shape[1] == K, f"All A matrices must have same K dimension"
        assert group_B[i].shape[0] == K, f"All B matrices must have K dimension matching A"

    # Get sizes for each GEMM
    Ms = [A.shape[0] for A in group_A]
    Ns = [B.shape[1] for B in group_B]

    # Find maximum dimensions
    max_M = max(Ms)
    max_N = max(Ns)

    # Calculate offsets
    M_offsets = [0]
    N_offsets = [0]
    for i in range(G):
        M_offsets.append(M_offsets[-1] + Ms[i])
        N_offsets.append(N_offsets[-1] + Ns[i])

    # Concatenate tensors
    A_concat = torch.cat(group_A, dim=0)  # [total_M, K]
    B_concat = torch.cat(group_B, dim=1)  # [K, total_N]

    # Create size and offset tensors
    group_sizes_M = torch.tensor(Ms, dtype=torch.int32, device=device)
    group_sizes_N = torch.tensor(Ns, dtype=torch.int32, device=device)
    group_offsets_M = torch.tensor(M_offsets, dtype=torch.int32, device=device)
    group_offsets_N = torch.tensor(N_offsets, dtype=torch.int32, device=device)

    # Create dummy tensors to pass dimensions
    max_M_tensor = torch.empty(max_M, device=device)
    max_N_tensor = torch.empty(max_N, device=device)

    return (A_concat, B_concat, group_sizes_M, group_sizes_N,
            group_offsets_M, group_offsets_N, max_M_tensor, max_N_tensor)


def split_output(C_concat: torch.Tensor, group_sizes_M: list[int], group_sizes_N: list[int]) -> list[torch.Tensor]:
    """Split concatenated output back into individual matrices"""
    outputs = []
    M_offset = 0
    N_offset = 0

    for M, N in zip(group_sizes_M, group_sizes_N):
        C = C_concat[M_offset:M_offset+M, N_offset:N_offset+N]
        outputs.append(C)
        M_offset += M
        N_offset += N

    return outputs


def grouped_gemm_tritonbench(group_A: list[torch.Tensor], group_B: list[torch.Tensor]) -> list[torch.Tensor]:
    """Wrapper function for tritonbench compatibility"""
    # Use the concatenated approach for better performance
    kernel_args = grouped_gemm_helion_kernel_args_gen(group_A, group_B)
    C_concat = grouped_gemm(*kernel_args)

    # Split output back into individual matrices
    Ms = [A.shape[0] for A in group_A]
    Ns = [B.shape[1] for B in group_B]
    return split_output(C_concat, Ms, Ns)


def grouped_gemm_pytorch(group_A: list[torch.Tensor], group_B: list[torch.Tensor]) -> list[torch.Tensor]:
    """Reference PyTorch implementation"""
    outputs = []
    for A, B in zip(group_A, group_B):
        C = torch.matmul(A, B)
        outputs.append(C)
    return outputs


def check(group_size: int = 4, base_size: int = 256) -> None:
    """Test the grouped GEMM implementation"""
    dtype = torch.float16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create test data with varying sizes
    group_A = []
    group_B = []

    for i in range(group_size):
        # Vary sizes for each GEMM to test handling of different dimensions
        M = base_size + i * 64
        N = base_size + (i + 1) * 32
        K = base_size  # Keep K constant for concatenation

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)

        group_A.append(A)
        group_B.append(B)

    # Test the concatenated kernel
    kernel_args = grouped_gemm_helion_kernel_args_gen(group_A, group_B)

    def helion_fn() -> torch.Tensor:
        return grouped_gemm(*kernel_args)

    def reference_fn() -> torch.Tensor:
        # Create reference output in concatenated form
        C_list = grouped_gemm_pytorch(group_A, group_B)

        # Concatenate in block-diagonal form
        total_M = sum(A.shape[0] for A in group_A)
        total_N = sum(B.shape[1] for B in group_B)
        C_concat = torch.zeros(total_M, total_N, device=device, dtype=torch.promote_types(group_A[0].dtype, group_B[0].dtype))

        M_offset = 0
        N_offset = 0
        for C in C_list:
            M, N = C.shape
            C_concat[M_offset:M_offset+M, N_offset:N_offset+N] = C
            M_offset += M
            N_offset += N

        return C_concat

    # Compare outputs
    run_example(helion_fn, reference_fn, ())


def main() -> None:
    # Test with different configurations
    print("Testing grouped GEMM with group_size=4, base_size=256")
    check(group_size=4, base_size=256)

    print("\nTesting grouped GEMM with group_size=8, base_size=128")
    check(group_size=8, base_size=128)

    print("\nTesting grouped GEMM with group_size=2, base_size=512")
    check(group_size=2, base_size=512)


if __name__ == "__main__":
    main()
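
For reference, here is a minimal usage sketch of the grouped_gemm_tritonbench wrapper defined above, checked against plain torch.matmul; the shapes, device, and tolerances are illustrative assumptions and are not part of the commit.

# Hypothetical usage sketch (not part of this commit); assumes a CUDA device
# and that examples.grouped_gemm is importable from the repository root.
import torch

from examples.grouped_gemm import grouped_gemm_tritonbench

group_A = [torch.randn(m, 128, device="cuda", dtype=torch.float16) for m in (64, 96, 128)]
group_B = [torch.randn(128, n, device="cuda", dtype=torch.float16) for n in (32, 64, 96)]

outputs = grouped_gemm_tritonbench(group_A, group_B)
for C, A, B in zip(outputs, group_A, group_B):
    # Each split output should match a plain matmul of the corresponding pair.
    torch.testing.assert_close(C, A @ B, rtol=1e-2, atol=1e-2)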
