diff --git a/benchmarks/run.py b/benchmarks/run.py index c9731ee01..6b755840d 100644 --- a/benchmarks/run.py +++ b/benchmarks/run.py @@ -20,14 +20,14 @@ from typing import Callable # Maps tritonbench op names to Helion kernel examples -KERNEL_MAPPINGS: dict[str, tuple[str, str] | tuple[str, str, dict[str, Any]]] = { - # <op_name>: (<module_path>, <function_name>, <extra_args>) +KERNEL_MAPPINGS: dict[str, tuple[str, str]] = { + # <op_name>: (<module_path>, <function_name>) "vector_add": ("examples.add", "add"), "embedding": ("examples.embedding", "embedding_tritonbench"), "vector_exp": ("examples.exp", "exp_tritonbench"), - # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg. - "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}), + "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench"), "sum": ("examples.sum", "sum_tritonbench"), + "jagged_mean": ("examples.jagged_mean", "jagged_mean_tritonbench"), } @@ -168,14 +168,8 @@ def main() -> None: # Check if kernel is in the mapping table assert kernel_name in KERNEL_MAPPINGS - mapping = KERNEL_MAPPINGS[kernel_name] - - # Parse mapping - can be (module, func) or (module, func, extra_args) - if len(mapping) == 2: - module_path, func_name = mapping - kernel_extra_args = {} - else: - module_path, func_name, kernel_extra_args = mapping + module_path, func_name = KERNEL_MAPPINGS[kernel_name] + # Import from the mapped module try: module = importlib.import_module(module_path) @@ -215,14 +209,17 @@ def main() -> None: assert "--op" not in tritonbench_args tritonbench_args = ["--op", operator_name, *tritonbench_args] - # Apply kernel-specific default arguments if not already specified by user - for arg_name, arg_value in kernel_extra_args.items(): - # Convert underscore to hyphen for CLI args (e.g., num_inputs -> --num-inputs) - cli_arg = f"--{arg_name.replace('_', '-')}" - if cli_arg not in tritonbench_args: - tritonbench_args.extend([cli_arg, str(arg_value)]) + # Get module's TRITONBENCH_ARGS if any + module_args = 
getattr(module, "TRITONBENCH_ARGS", {}) + + # Add module args to tritonbench_args if not already present + for arg_name, arg_value in module_args.items(): + arg_flag = f"--{arg_name.replace('_', '-')}" + if arg_flag not in tritonbench_args: + tritonbench_args.extend([arg_flag, str(arg_value)]) - tb_args = tb_parser.parse_args(tritonbench_args) + # Parse known args and collect unknown ones for operator + tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args) # Register the Helion kernel with tritonbench BEFORE importing the operator from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports] @@ -286,12 +283,12 @@ def _inner() -> Callable[..., Any]: file=sys.stderr, ) - # Create and run the operator - op = Operator(tb_args=tb_args, extra_args={}) + # Create and run the operator with unknown args + op = Operator(tb_args=tb_args, extra_args=unknown_args) # Run with proper parameters - warmup = getattr(tb_args, "warmup", 25) - rep = getattr(tb_args, "iter", 100) + warmup = int(getattr(tb_args, "warmup", 25)) + rep = int(getattr(tb_args, "iter", 100)) op.run(warmup=warmup, rep=rep) # Print results diff --git a/examples/jagged_mean.py b/examples/jagged_mean.py index 710abd5d8..540865b14 100644 --- a/examples/jagged_mean.py +++ b/examples/jagged_mean.py @@ -5,6 +5,12 @@ import helion from helion._testing import run_example import helion.language as hl +from helion.utils import get_gpu_memory_info + +# TritonBench configuration - adjust based on available GPU memory +if get_gpu_memory_info()[0] < 16.0: + # Low memory configuration + TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64} @helion.kernel() diff --git a/examples/rms_norm.py b/examples/rms_norm.py index 554332826..c1b468410 100644 --- a/examples/rms_norm.py +++ b/examples/rms_norm.py @@ -6,6 +6,10 @@ from helion._testing import run_example import helion.language as hl +# TritonBench configuration +# TODO(yf225): reduction dim size = 8192 currently throws error. 
After it's fixed we can remove "num_inputs" extra arg. +TRITONBENCH_ARGS = {"num_inputs": 3} + @helion.kernel(static_shapes=True) def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor: diff --git a/helion/utils.py b/helion/utils.py new file mode 100644 index 000000000..0e6f91774 --- /dev/null +++ b/helion/utils.py @@ -0,0 +1,35 @@ +from __future__ import annotations + +import torch + + +def get_gpu_memory_info(device_id: int | None = None) -> tuple[float, float]: + """ + Get total and available GPU memory in GB. + + Args: + device_id: GPU device ID. If None, uses current device. + + Returns: + Tuple of (total_memory_gb, available_memory_gb) + """ + if not torch.cuda.is_available(): + return (0.0, 0.0) + + if device_id is None: + device_id = torch.cuda.current_device() + + # Get total memory + total_memory = torch.cuda.get_device_properties(device_id).total_memory + + # Get reserved memory (memory allocated by the caching allocator) + reserved_memory = torch.cuda.memory_reserved(device_id) + + # Available memory is approximately total - reserved + available_memory = total_memory - reserved_memory + + # Convert to GB + total_gb = total_memory / (1024**3) + available_gb = available_memory / (1024**3) + + return (total_gb, available_gb)