2626from typing import Callable
2727
2828# Maps tritonbench op names to Helion kernel examples
29- KERNEL_MAPPINGS : dict [str , tuple [str , str , str ]] = {
29+ # Can map to a single kernel or a list of kernel variants
30+ KERNEL_MAPPINGS : dict [str , tuple [str , str , str ] | list [tuple [str , str , str ]]] = {
3031 # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
3132 "vector_add" : ("tritonbench.operators.vector_add.operator" , "examples.add" , "add" ),
3233 "embedding" : (
8081 "examples.layer_norm" ,
8182 "layer_norm_fwd" ,
8283 ),
84+ # Multiple kernel variants:
85+ "gemm" : (
86+ "tritonbench.operators.gemm.operator" ,
87+ [
88+ ("examples.matmul" , "matmul" ),
89+ ("examples.matmul_split_k" , "matmul_split_k" ),
90+ ],
91+ ),
8392}
8493
8594
@@ -210,7 +219,7 @@ def run_kernel(
210219 tritonbench_args : list [str ],
211220 input_shard_info : tuple [int , int ] | None = None ,
212221) -> None :
213- """Run a single kernel benchmark."""
222+ """Run a kernel benchmark, handling both single and multiple variants ."""
214223 # Check if kernel is in the mapping table
215224 if kernel_name not in KERNEL_MAPPINGS :
216225 print (f"Error: Unknown kernel '{ kernel_name } '" , file = sys .stderr )
@@ -219,25 +228,32 @@ def run_kernel(
219228 )
220229 sys .exit (1 )
221230
222- tritonbench_module , module_path , func_name = KERNEL_MAPPINGS [kernel_name ]
231+ mapping = KERNEL_MAPPINGS [kernel_name ]
223232
224- # Import from the mapped module
225- try :
226- module = importlib .import_module (module_path )
227- if not hasattr (module , func_name ):
228- print (
229- f"Error: Module '{ module_path } ' does not have a function named '{ func_name } '" ,
230- file = sys .stderr ,
231- )
232- sys .exit (1 )
233- kernel_func = getattr (module , func_name )
234- except ImportError as e :
235- print (
236- f"Error: Could not import { func_name } from { module_path } " , file = sys .stderr
237- )
238- print (f"Import error: { e } " , file = sys .stderr )
239- sys .exit (1 )
240- return
233+ # Normalize to list of variants format
234+ if len (mapping ) == 2 and isinstance (mapping [1 ], list ):
235+ # Multiple variants with shared tritonbench module
236+ tritonbench_module = mapping [0 ]
237+ variants = mapping [1 ]
238+ else :
239+ # Single kernel with full mapping - convert to list format
240+ tritonbench_module , module_path , func_name = mapping
241+ variants = [(module_path , func_name )]
242+
243+ # Run all variants in the same benchmark
244+ run_kernel_variants (
245+ kernel_name , tritonbench_module , variants , tritonbench_args , input_shard_info
246+ )
247+
248+
249+ def run_kernel_variants (
250+ kernel_name : str ,
251+ tritonbench_module : str ,
252+ variants : list [tuple [str , str ]],
253+ tritonbench_args : list [str ],
254+ input_shard_info : tuple [int , int ] | None = None ,
255+ ) -> None :
256+ """Run kernel variants in the same benchmark run."""
241257
242258 # Import tritonbench components
243259 try :
@@ -260,19 +276,26 @@ def run_kernel(
260276 assert "--op" not in tritonbench_args
261277 tritonbench_args = ["--op" , operator_name , * tritonbench_args ]
262278
263- # Get module's TRITONBENCH_ARGS if any
264- module_args = getattr (module , "TRITONBENCH_ARGS" , {})
279+ # Collect all module args from all variants
280+ all_module_args = {}
281+ for module_path , func_name in variants :
282+ try :
283+ module = importlib .import_module (module_path )
284+ module_args = getattr (module , "TRITONBENCH_ARGS" , {})
285+ all_module_args .update (module_args )
286+ except ImportError :
287+ pass
265288
266289 # Add module args to tritonbench_args if not already present
267- for arg_name , arg_value in module_args .items ():
290+ for arg_name , arg_value in all_module_args .items ():
268291 arg_flag = f"--{ arg_name .replace ('_' , '-' )} "
269292 if arg_flag not in tritonbench_args :
270293 tritonbench_args .extend ([arg_flag , str (arg_value )])
271294
272295 # Parse known args and collect unknown ones for operator
273296 tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
274297
275- # Import and run the operator
298+ # Import and get the operator class
276299 try :
277300 operator_module = importlib .import_module (tritonbench_module )
278301 Operator = operator_module .Operator
@@ -285,64 +308,94 @@ def run_kernel(
285308 print (f"Import error: { e } " , file = sys .stderr )
286309 sys .exit (1 )
287310
288- # Create the benchmark method
289- def helion_method (
290- self : object ,
291- * args : object ,
292- ) -> Callable [..., object ]:
293- """Helion implementation."""
294-
295- # Reset all Helion kernels before creating the benchmark function
296- # so that each input size can go through its own autotuning.
297- from helion .runtime .kernel import Kernel
298-
299- for attr_name in dir (module ):
300- attr = getattr (module , attr_name )
301- if isinstance (attr , Kernel ):
302- attr .reset ()
303-
304- def _inner () -> Callable [..., Any ] | object :
305- # Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
306- # This ensures we run autotuning even if the kernel has pre-specified configs
307- if os .environ .get ("HELION_USE_DEFAULT_CONFIG" , "0" ) != "1" :
308- # Find all Kernel objects in the module and force autotuning
309- for attr_name in dir (module ):
310- attr = getattr (module , attr_name )
311- if isinstance (attr , Kernel ):
312- attr .settings .force_autotune = True
313-
314- result = kernel_func (* args )
315- if callable (result ):
316- return result ()
317- return result
318-
319- return _inner
320-
321- # Method name for the benchmark
322- helion_method_name = f"helion_{ kernel_name } "
323-
324311 # Import register_benchmark API
325312 from tritonbench .utils .triton_op import ( # pyright: ignore[reportMissingImports]
326313 register_benchmark ,
327314 )
328315
329- # Use register_benchmark decorator
330- decorated_method = register_benchmark (
331- operator_name = operator_name ,
332- func_name = helion_method_name ,
333- baseline = False ,
334- enabled = True ,
335- fwd_only = False ,
336- label = helion_method_name ,
337- )(helion_method )
338-
339- # Set the decorated method on the Operator class
340- setattr (Operator , helion_method_name , decorated_method )
341-
342- print (
343- f"Running { operator_name } benchmark with Helion implementation...\n " ,
344- file = sys .stderr ,
345- )
316+ # Register all variants as separate methods
317+ for module_path , func_name in variants :
318+ # Import the kernel function
319+ try :
320+ module = importlib .import_module (module_path )
321+ if not hasattr (module , func_name ):
322+ print (
323+ f"Error: Module '{ module_path } ' does not have a function named '{ func_name } '" ,
324+ file = sys .stderr ,
325+ )
326+ continue
327+ kernel_func = getattr (module , func_name )
328+ except ImportError as e :
329+ print (
330+ f"Error: Could not import { func_name } from { module_path } " ,
331+ file = sys .stderr ,
332+ )
333+ print (f"Import error: { e } " , file = sys .stderr )
334+ continue
335+
336+ # Create the benchmark method closure to capture the correct module and function
337+ def create_helion_method (mod , kfunc ):
338+ def helion_method (
339+ self : object ,
340+ * args : object ,
341+ ) -> Callable [..., object ]:
342+ """Helion implementation."""
343+
344+ # Reset all Helion kernels before creating the benchmark function
345+ # so that each input size can go through its own autotuning.
346+ from helion .runtime .kernel import Kernel
347+
348+ for attr_name in dir (mod ):
349+ attr = getattr (mod , attr_name )
350+ if isinstance (attr , Kernel ):
351+ attr .reset ()
352+
353+ def _inner () -> Callable [..., Any ] | object :
354+ # Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
355+ # This ensures we run autotuning even if the kernel has pre-specified configs
356+ if os .environ .get ("HELION_USE_DEFAULT_CONFIG" , "0" ) != "1" :
357+ # Find all Kernel objects in the module and force autotuning
358+ for attr_name in dir (mod ):
359+ attr = getattr (mod , attr_name )
360+ if isinstance (attr , Kernel ):
361+ attr .settings .force_autotune = True
362+
363+ result = kfunc (* args )
364+ if callable (result ):
365+ return result ()
366+ return result
367+
368+ return _inner
369+
370+ return helion_method
371+
372+ # Method name for the benchmark
373+ variant_name = func_name
374+ helion_method_name = f"helion_{ variant_name } "
375+
376+ # Use register_benchmark decorator
377+ decorated_method = register_benchmark (
378+ operator_name = operator_name ,
379+ func_name = helion_method_name ,
380+ baseline = False ,
381+ enabled = True ,
382+ fwd_only = False ,
383+ label = helion_method_name ,
384+ )(create_helion_method (module , kernel_func ))
385+
386+ # Set the decorated method on the Operator class
387+ setattr (Operator , helion_method_name , decorated_method )
388+
389+ if len (variants ) == 1 :
390+ print (
391+ f"Running { operator_name } benchmark with Helion implementation...\n " ,
392+ file = sys .stderr ,
393+ )
394+ else :
395+ print (
396+ f"Running { operator_name } benchmark with { len (variants )} Helion implementations...\n " ,
397+ file = sys .stderr ,
398+ )
346399
347400 # Create and run the operator with unknown args
348401 op = Operator (tb_args = tb_args , extra_args = unknown_args )
0 commit comments