|
20 | 20 | from typing import Callable |
21 | 21 |
|
22 | 22 | # Maps tritonbench op names to Helion kernel examples |
23 | | -KERNEL_MAPPINGS: dict[str, tuple[str, str] | tuple[str, str, dict[str, Any]]] = { |
24 | | - # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>, <optional_extra_args>) |
| 23 | +KERNEL_MAPPINGS: dict[str, tuple[str, str]] = { |
| 24 | + # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>) |
25 | 25 | "vector_add": ("examples.add", "add"), |
26 | 26 | "embedding": ("examples.embedding", "embedding_tritonbench"), |
27 | 27 | "vector_exp": ("examples.exp", "exp_tritonbench"), |
28 | | - # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. |
29 | | - "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}), |
| 28 | + "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench"), |
30 | 29 | "sum": ("examples.sum", "sum_tritonbench"), |
| 30 | + "jagged_mean": ("examples.jagged_mean", "jagged_mean_tritonbench"), |
31 | 31 | } |
32 | 32 |
|
33 | 33 |
|
@@ -168,14 +168,8 @@ def main() -> None: |
168 | 168 |
|
169 | 169 | # Check if kernel is in the mapping table |
170 | 170 | assert kernel_name in KERNEL_MAPPINGS |
171 | | - mapping = KERNEL_MAPPINGS[kernel_name] |
172 | | - |
173 | | - # Parse mapping - can be (module, func) or (module, func, extra_args) |
174 | | - if len(mapping) == 2: |
175 | | - module_path, func_name = mapping |
176 | | - kernel_extra_args = {} |
177 | | - else: |
178 | | - module_path, func_name, kernel_extra_args = mapping |
| 171 | + module_path, func_name = KERNEL_MAPPINGS[kernel_name] |
| 172 | + |
179 | 173 | # Import from the mapped module |
180 | 174 | try: |
181 | 175 | module = importlib.import_module(module_path) |
@@ -215,14 +209,17 @@ def main() -> None: |
215 | 209 | assert "--op" not in tritonbench_args |
216 | 210 | tritonbench_args = ["--op", operator_name, *tritonbench_args] |
217 | 211 |
|
218 | | - # Apply kernel-specific default arguments if not already specified by user |
219 | | - for arg_name, arg_value in kernel_extra_args.items(): |
220 | | - # Convert underscore to hyphen for CLI args (e.g., num_inputs -> --num-inputs) |
221 | | - cli_arg = f"--{arg_name.replace('_', '-')}" |
222 | | - if cli_arg not in tritonbench_args: |
223 | | - tritonbench_args.extend([cli_arg, str(arg_value)]) |
| 212 | + # Get module's TRITONBENCH_ARGS if any |
| 213 | + module_args = getattr(module, "TRITONBENCH_ARGS", {}) |
| 214 | + |
| 215 | + # Add module args to tritonbench_args if not already present |
| 216 | + for arg_name, arg_value in module_args.items(): |
| 217 | + arg_flag = f"--{arg_name.replace('_', '-')}" |
| 218 | + if arg_flag not in tritonbench_args: |
| 219 | + tritonbench_args.extend([arg_flag, str(arg_value)]) |
224 | 220 |
|
225 | | - tb_args = tb_parser.parse_args(tritonbench_args) |
| 221 | + # Parse known args and collect unknown ones for operator |
| 222 | + tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args) |
226 | 223 |
|
227 | 224 | # Register the Helion kernel with tritonbench BEFORE importing the operator |
228 | 225 | from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports] |
@@ -286,12 +283,12 @@ def _inner() -> Callable[..., Any]: |
286 | 283 | file=sys.stderr, |
287 | 284 | ) |
288 | 285 |
|
289 | | - # Create and run the operator |
290 | | - op = Operator(tb_args=tb_args, extra_args={}) |
| 286 | + # Create and run the operator with unknown args |
| 287 | + op = Operator(tb_args=tb_args, extra_args=unknown_args) |
291 | 288 |
|
292 | 289 | # Run with proper parameters |
293 | | - warmup = getattr(tb_args, "warmup", 25) |
294 | | - rep = getattr(tb_args, "iter", 100) |
| 290 | + warmup = int(getattr(tb_args, "warmup", 25)) |
| 291 | + rep = int(getattr(tb_args, "iter", 100)) |
295 | 292 | op.run(warmup=warmup, rep=rep) |
296 | 293 |
|
297 | 294 | # Print results |
|
0 commit comments