1+ """
2+ Grouped GEMM - Multiple matrix multiplications in a single kernel launch
3+ """

from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl


@helion.kernel(static_shapes=False)
def grouped_gemm(
    A_concat: torch.Tensor,  # Concatenated A matrices [total_M, K]
    B_concat: torch.Tensor,  # Concatenated B matrices [K, total_N]
    group_sizes_M: torch.Tensor,  # [G] - M size for each GEMM
    group_sizes_N: torch.Tensor,  # [G] - N size for each GEMM
    group_offsets_M: torch.Tensor,  # [G+1] - Starting M offset for each GEMM
    group_offsets_N: torch.Tensor,  # [G+1] - Starting N offset for each GEMM
    max_M_tensor: torch.Tensor,  # Dummy tensor of size max(M)
    max_N_tensor: torch.Tensor,  # Dummy tensor of size max(N)
) -> torch.Tensor:  # [total_M, total_N] - Concatenated output
    """Grouped GEMM kernel using concatenated tensors."""
    G = group_sizes_M.shape[0]
    total_M, K = A_concat.shape
    _, total_N = B_concat.shape
    max_M = max_M_tensor.numel()
    max_N = max_N_tensor.numel()
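    # The dummy tensors exist only to communicate max(M) and max(N); their
    # element counts become the tile-loop extents below.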

    # Allocate output tensor
    C_concat = torch.zeros(
        total_M, total_N,
        dtype=torch.promote_types(A_concat.dtype, B_concat.dtype),
        device=A_concat.device,
    )
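    # Zero-initialization matters: only the block-diagonal regions are ever
    # written, so everything outside them stays zero.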

    # Process each GEMM
    for g_idx in hl.grid(G):
        # Get dimensions and offsets for this GEMM
        M = group_sizes_M[g_idx]
        N = group_sizes_N[g_idx]
        M_start = group_offsets_M[g_idx]
        N_start = group_offsets_N[g_idx]

        # Skip empty GEMMs (boolean multiplication stands in for Python's
        # `and`, which device-side values do not support)
        valid_gemm = (M > 0) * (N > 0)
        if valid_gemm:
            # Tile over output dimensions
            for tile_m, tile_n in hl.tile([max_M, max_N]):
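                # Every group tiles over the same [max_M, max_N] extents;
                # tiles past a group's actual (M, N) are masked out below.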
                # Get tile indices
                m_indices = tile_m.index
                n_indices = tile_n.index

                # Create masks for valid elements
                m_valid = m_indices < M
                n_valid = n_indices < N

                # Calculate global indices
                m_indices_valid = torch.where(m_valid, m_indices, 0)
                n_indices_valid = torch.where(n_valid, n_indices, 0)

                # Global indices in concatenated tensors
                global_m = M_start + m_indices_valid
                global_n = N_start + n_indices_valid
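                # Lanes past (M, N) were clamped to index 0 above so loads stay
                # in bounds; their values are discarded by the masked store below.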

                # Initialize accumulator
                acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)

                # Accumulate over K dimension
                for tile_k in hl.tile(K):
                    k_indices = tile_k.index

                    # Load tiles from concatenated tensors
                    A_tile = A_concat[global_m, k_indices]
                    B_tile = B_concat[k_indices, global_n]

                    # Accumulate
                    acc = torch.addmm(acc, A_tile, B_tile)
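                # The running sum stays in float32 regardless of the input dtype;
                # it is cast back to the output dtype at the store below.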

                # Write back to output with masking
                block_m = acc.size(0)
                block_n = acc.size(1)

                # Get existing values
                existing_values = C_concat[global_m, global_n]

                # Create 2D mask for output
                mask_2d = m_valid.view(block_m, 1).expand(
                    block_m, block_n
                ) & n_valid.view(1, block_n).expand(block_m, block_n)

                # Write results only for valid positions
                C_concat[global_m, global_n] = torch.where(
                    mask_2d, acc.to(C_concat.dtype), existing_values
                )

    return C_concat


def grouped_gemm_helion_kernel_args_gen(
    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
]:
    """Generate arguments for the Helion kernel by concatenating inputs."""
    device = group_A[0].device
    G = len(group_A)

    # Check that all matrices share the same K dimension
    K = group_A[0].shape[1]
    for i in range(G):
        assert group_A[i].shape[1] == K, f"A[{i}] must have K={K} columns"
        assert group_B[i].shape[0] == K, f"B[{i}] must have K={K} rows to match A"

    # Get sizes for each GEMM
    Ms = [A.shape[0] for A in group_A]
    Ns = [B.shape[1] for B in group_B]

    # Find maximum dimensions
    max_M = max(Ms)
    max_N = max(Ns)

    # Calculate offsets
    M_offsets = [0]
    N_offsets = [0]
    for i in range(G):
        M_offsets.append(M_offsets[-1] + Ms[i])
        N_offsets.append(N_offsets[-1] + Ns[i])
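    # Each offsets list is a prefix sum with G + 1 entries, so
    # offsets[g + 1] - offsets[g] equals group g's size, matching the
    # [G+1] annotation in the kernel signature.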

    # Concatenate tensors
    A_concat = torch.cat(group_A, dim=0)  # [total_M, K]
    B_concat = torch.cat(group_B, dim=1)  # [K, total_N]

    # Create size and offset tensors
    group_sizes_M = torch.tensor(Ms, dtype=torch.int32, device=device)
    group_sizes_N = torch.tensor(Ns, dtype=torch.int32, device=device)
    group_offsets_M = torch.tensor(M_offsets, dtype=torch.int32, device=device)
    group_offsets_N = torch.tensor(N_offsets, dtype=torch.int32, device=device)

    # Create dummy tensors to pass dimensions
    max_M_tensor = torch.empty(max_M, device=device)
    max_N_tensor = torch.empty(max_N, device=device)
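    # Only .numel() of these dummies is read inside the kernel; their dtype
    # and (uninitialized) contents are never used.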

    return (
        A_concat,
        B_concat,
        group_sizes_M,
        group_sizes_N,
        group_offsets_M,
        group_offsets_N,
        max_M_tensor,
        max_N_tensor,
    )


def split_output(
    C_concat: torch.Tensor, group_sizes_M: list[int], group_sizes_N: list[int]
) -> list[torch.Tensor]:
    """Split concatenated output back into individual matrices."""
    outputs = []
    M_offset = 0
    N_offset = 0

    for M, N in zip(group_sizes_M, group_sizes_N):
        C = C_concat[M_offset : M_offset + M, N_offset : N_offset + N]
        outputs.append(C)
        M_offset += M
        N_offset += N

    return outputs


def grouped_gemm_tritonbench(
    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> list[torch.Tensor]:
    """Wrapper function for tritonbench compatibility."""
    # Use the concatenated approach so a single kernel launch covers every GEMM
    kernel_args = grouped_gemm_helion_kernel_args_gen(group_A, group_B)
    C_concat = grouped_gemm(*kernel_args)

    # Split the output back into individual matrices
    Ms = [A.shape[0] for A in group_A]
    Ns = [B.shape[1] for B in group_B]
    return split_output(C_concat, Ms, Ns)


def grouped_gemm_pytorch(
    group_A: list[torch.Tensor], group_B: list[torch.Tensor]
) -> list[torch.Tensor]:
    """Reference PyTorch implementation: one matmul per group."""
    return [torch.matmul(A, B) for A, B in zip(group_A, group_B)]


def check(group_size: int = 4, base_size: int = 256) -> None:
    """Test the grouped GEMM implementation."""
    dtype = torch.float16
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Create test data with varying sizes
    group_A = []
    group_B = []

    for i in range(group_size):
        # Vary sizes for each GEMM to test handling of different dimensions
        M = base_size + i * 64
        N = base_size + (i + 1) * 32
        K = base_size  # Keep K constant for concatenation

        A = torch.randn(M, K, device=device, dtype=dtype)
        B = torch.randn(K, N, device=device, dtype=dtype)

        group_A.append(A)
        group_B.append(B)

    # Test the concatenated kernel
    kernel_args = grouped_gemm_helion_kernel_args_gen(group_A, group_B)

    def helion_fn() -> torch.Tensor:
        return grouped_gemm(*kernel_args)

    def reference_fn() -> torch.Tensor:
        # Create reference output in concatenated form
        C_list = grouped_gemm_pytorch(group_A, group_B)

        # Concatenate in block-diagonal form
        total_M = sum(A.shape[0] for A in group_A)
        total_N = sum(B.shape[1] for B in group_B)
        C_concat = torch.zeros(
            total_M,
            total_N,
            device=device,
            dtype=torch.promote_types(group_A[0].dtype, group_B[0].dtype),
        )

        M_offset = 0
        N_offset = 0
        for C in C_list:
            M, N = C.shape
            C_concat[M_offset : M_offset + M, N_offset : N_offset + N] = C
            M_offset += M
            N_offset += N

        return C_concat

    # Compare outputs
    run_example(helion_fn, reference_fn, ())


def main() -> None:
    # Test with different configurations
    print("Testing grouped GEMM with group_size=4, base_size=256")
    check(group_size=4, base_size=256)

    print("\nTesting grouped GEMM with group_size=8, base_size=128")
    check(group_size=8, base_size=128)

    print("\nTesting grouped GEMM with group_size=2, base_size=512")
    check(group_size=2, base_size=512)


if __name__ == "__main__":
    main()