2626from typing import Callable
2727
2828# Maps tritonbench op names to Helion kernel examples
29- KERNEL_MAPPINGS : dict [str , tuple [str , str , str ]] = {
29+ # Can map to a single kernel or a list of kernel variants
30+ KERNEL_MAPPINGS : dict [str , tuple [str , str , str ] | tuple [str , list [tuple [str , str ]]]] = {
3031 # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
3132 "vector_add" : ("tritonbench.operators.vector_add.operator" , "examples.add" , "add" ),
3233 "embedding" : (
8081 "examples.layer_norm" ,
8182 "layer_norm_fwd" ,
8283 ),
84+ # Multiple kernel variants:
85+ "gemm" : (
86+ "tritonbench.operators.gemm.operator" ,
87+ [
88+ ("examples.matmul" , "matmul" ),
89+ ("examples.matmul_split_k" , "matmul_split_k" ),
90+ ],
91+ ),
8392}
8493
8594
@@ -210,7 +219,7 @@ def run_kernel(
210219 tritonbench_args : list [str ],
211220 input_shard_info : tuple [int , int ] | None = None ,
212221) -> None :
213- """Run a single kernel benchmark."""
222+ """Run a kernel benchmark, handling both single and multiple variants ."""
214223 # Check if kernel is in the mapping table
215224 if kernel_name not in KERNEL_MAPPINGS :
216225 print (f"Error: Unknown kernel '{ kernel_name } '" , file = sys .stderr )
@@ -219,25 +228,33 @@ def run_kernel(
219228 )
220229 sys .exit (1 )
221230
222- tritonbench_module , module_path , func_name = KERNEL_MAPPINGS [kernel_name ]
231+ mapping = KERNEL_MAPPINGS [kernel_name ]
232+
233+ # Normalize to list of variants format
234+ if len (mapping ) == 2 and isinstance (mapping [1 ], list ):
235+ # Multiple variants with shared tritonbench module
236+ tritonbench_module = mapping [0 ]
237+ variants = mapping [1 ]
238+ else :
239+ # Single kernel with full mapping - convert to list format
240+ assert len (mapping ) == 3 # Type narrowing for pyright
241+ tritonbench_module , module_path , func_name = mapping
242+ variants = [(module_path , func_name )]
243+
244+ # Run all variants in the same benchmark
245+ run_kernel_variants (
246+ kernel_name , tritonbench_module , variants , tritonbench_args , input_shard_info
247+ )
223248
224- # Import from the mapped module
225- try :
226- module = importlib .import_module (module_path )
227- if not hasattr (module , func_name ):
228- print (
229- f"Error: Module '{ module_path } ' does not have a function named '{ func_name } '" ,
230- file = sys .stderr ,
231- )
232- sys .exit (1 )
233- kernel_func = getattr (module , func_name )
234- except ImportError as e :
235- print (
236- f"Error: Could not import { func_name } from { module_path } " , file = sys .stderr
237- )
238- print (f"Import error: { e } " , file = sys .stderr )
239- sys .exit (1 )
240- return
249+
250+ def run_kernel_variants (
251+ kernel_name : str ,
252+ tritonbench_module : str ,
253+ variants : list [tuple [str , str ]],
254+ tritonbench_args : list [str ],
255+ input_shard_info : tuple [int , int ] | None = None ,
256+ ) -> None :
257+ """Run kernel variants in the same benchmark run."""
241258
242259 # Import tritonbench components
243260 try :
@@ -260,19 +277,26 @@ def run_kernel(
260277 assert "--op" not in tritonbench_args
261278 tritonbench_args = ["--op" , operator_name , * tritonbench_args ]
262279
263- # Get module's TRITONBENCH_ARGS if any
264- module_args = getattr (module , "TRITONBENCH_ARGS" , {})
280+ # Collect all module args from all variants
281+ all_module_args = {}
282+ for module_path , _ in variants :
283+ try :
284+ module = importlib .import_module (module_path )
285+ module_args = getattr (module , "TRITONBENCH_ARGS" , {})
286+ all_module_args .update (module_args )
287+ except ImportError :
288+ pass
265289
266290 # Add module args to tritonbench_args if not already present
267- for arg_name , arg_value in module_args .items ():
291+ for arg_name , arg_value in all_module_args .items ():
268292 arg_flag = f"--{ arg_name .replace ('_' , '-' )} "
269293 if arg_flag not in tritonbench_args :
270294 tritonbench_args .extend ([arg_flag , str (arg_value )])
271295
272296 # Parse known args and collect unknown ones for operator
273297 tb_args , unknown_args = tb_parser .parse_known_args (tritonbench_args )
274298
275- # Import and run the operator
299+ # Import and get the operator class
276300 try :
277301 operator_module = importlib .import_module (tritonbench_module )
278302 Operator = operator_module .Operator
@@ -285,64 +309,97 @@ def run_kernel(
285309 print (f"Import error: { e } " , file = sys .stderr )
286310 sys .exit (1 )
287311
288- # Create the benchmark method
289- def helion_method (
290- self : object ,
291- * args : object ,
292- ) -> Callable [..., object ]:
293- """Helion implementation."""
294-
295- # Reset all Helion kernels before creating the benchmark function
296- # so that each input size can go through its own autotuning.
297- from helion .runtime .kernel import Kernel
298-
299- for attr_name in dir (module ):
300- attr = getattr (module , attr_name )
301- if isinstance (attr , Kernel ):
302- attr .reset ()
303-
304- def _inner () -> Callable [..., Any ] | object :
305- # Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
306- # This ensures we run autotuning even if the kernel has pre-specified configs
307- if os .environ .get ("HELION_USE_DEFAULT_CONFIG" , "0" ) != "1" :
308- # Find all Kernel objects in the module and force autotuning
309- for attr_name in dir (module ):
310- attr = getattr (module , attr_name )
311- if isinstance (attr , Kernel ):
312- attr .settings .force_autotune = True
313-
314- result = kernel_func (* args )
315- if callable (result ):
316- return result ()
317- return result
318-
319- return _inner
320-
321- # Method name for the benchmark
322- helion_method_name = f"helion_{ kernel_name } "
323-
324312 # Import register_benchmark API
325313 from tritonbench .utils .triton_op import ( # pyright: ignore[reportMissingImports]
326314 register_benchmark ,
327315 )
328316
329- # Use register_benchmark decorator
330- decorated_method = register_benchmark (
331- operator_name = operator_name ,
332- func_name = helion_method_name ,
333- baseline = False ,
334- enabled = True ,
335- fwd_only = False ,
336- label = helion_method_name ,
337- )(helion_method )
338-
339- # Set the decorated method on the Operator class
340- setattr (Operator , helion_method_name , decorated_method )
341-
342- print (
343- f"Running { operator_name } benchmark with Helion implementation...\n " ,
344- file = sys .stderr ,
345- )
317+ # Register all variants as separate methods
318+ for module_path , func_name in variants :
319+ # Import the kernel function
320+ try :
321+ module = importlib .import_module (module_path )
322+ if not hasattr (module , func_name ):
323+ print (
324+ f"Error: Module '{ module_path } ' does not have a function named '{ func_name } '" ,
325+ file = sys .stderr ,
326+ )
327+ continue
328+ kernel_func = getattr (module , func_name )
329+ except ImportError as e :
330+ print (
331+ f"Error: Could not import { func_name } from { module_path } " ,
332+ file = sys .stderr ,
333+ )
334+ print (f"Import error: { e } " , file = sys .stderr )
335+ continue
336+
337+ # Create the benchmark method closure to capture the correct module and function
338+ def create_helion_method (
339+ mod : Any , # noqa: ANN401
340+ kfunc : Callable [..., Any ],
341+ ) -> Callable [..., Any ]:
342+ def helion_method (
343+ self : object ,
344+ * args : object ,
345+ ) -> Callable [..., object ]:
346+ """Helion implementation."""
347+
348+ # Reset all Helion kernels before creating the benchmark function
349+ # so that each input size can go through its own autotuning.
350+ from helion .runtime .kernel import Kernel
351+
352+ for attr_name in dir (mod ):
353+ attr = getattr (mod , attr_name )
354+ if isinstance (attr , Kernel ):
355+ attr .reset ()
356+
357+ def _inner () -> Callable [..., Any ] | object :
358+ # Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
359+ # This ensures we run autotuning even if the kernel has pre-specified configs
360+ if os .environ .get ("HELION_USE_DEFAULT_CONFIG" , "0" ) != "1" :
361+ # Find all Kernel objects in the module and force autotuning
362+ for attr_name in dir (mod ):
363+ attr = getattr (mod , attr_name )
364+ if isinstance (attr , Kernel ):
365+ attr .settings .force_autotune = True
366+
367+ result = kfunc (* args )
368+ if callable (result ):
369+ return result ()
370+ return result
371+
372+ return _inner
373+
374+ return helion_method
375+
376+ # Method name for the benchmark
377+ variant_name = func_name
378+ helion_method_name = f"helion_{ variant_name } "
379+
380+ # Use register_benchmark decorator
381+ decorated_method = register_benchmark (
382+ operator_name = operator_name ,
383+ func_name = helion_method_name ,
384+ baseline = False ,
385+ enabled = True ,
386+ fwd_only = False ,
387+ label = helion_method_name ,
388+ )(create_helion_method (module , kernel_func ))
389+
390+ # Set the decorated method on the Operator class
391+ setattr (Operator , helion_method_name , decorated_method )
392+
393+ if len (variants ) == 1 :
394+ print (
395+ f"Running { operator_name } benchmark with Helion implementation...\n " ,
396+ file = sys .stderr ,
397+ )
398+ else :
399+ print (
400+ f"Running { operator_name } benchmark with { len (variants )} Helion implementations...\n " ,
401+ file = sys .stderr ,
402+ )
346403
347404 # Create and run the operator with unknown args
348405 op = Operator (tb_args = tb_args , extra_args = unknown_args )
0 commit comments