Commit 20ea010 (1 parent: 48cfdec)

[Benchmark] Add jagged_mean tritonbench integration

stack-info: PR: #264, branch: yf225/stack/10

File tree: 4 files changed (+65, -23 lines)

benchmarks/run.py
Lines changed: 20 additions & 23 deletions
@@ -20,14 +20,14 @@
 from typing import Callable
 
 # Maps tritonbench op names to Helion kernel examples
-KERNEL_MAPPINGS: dict[str, tuple[str, str] | tuple[str, str, dict[str, Any]]] = {
-    # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>, <optional_extra_args>)
+KERNEL_MAPPINGS: dict[str, tuple[str, str]] = {
+    # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("examples.add", "add"),
     "embedding": ("examples.embedding", "embedding_tritonbench"),
     "vector_exp": ("examples.exp", "exp_tritonbench"),
-    # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
-    "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}),
+    "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench"),
     "sum": ("examples.sum", "sum_tritonbench"),
+    "jagged_mean": ("examples.jagged_mean", "jagged_mean_tritonbench"),
 }
 
 
@@ -168,14 +168,8 @@ def main() -> None:
 
     # Check if kernel is in the mapping table
     assert kernel_name in KERNEL_MAPPINGS
-    mapping = KERNEL_MAPPINGS[kernel_name]
-
-    # Parse mapping - can be (module, func) or (module, func, extra_args)
-    if len(mapping) == 2:
-        module_path, func_name = mapping
-        kernel_extra_args = {}
-    else:
-        module_path, func_name, kernel_extra_args = mapping
+    module_path, func_name = KERNEL_MAPPINGS[kernel_name]
+
     # Import from the mapped module
     try:
         module = importlib.import_module(module_path)
@@ -215,14 +209,17 @@ def main() -> None:
     assert "--op" not in tritonbench_args
     tritonbench_args = ["--op", operator_name, *tritonbench_args]
 
-    # Apply kernel-specific default arguments if not already specified by user
-    for arg_name, arg_value in kernel_extra_args.items():
-        # Convert underscore to hyphen for CLI args (e.g., num_inputs -> --num-inputs)
-        cli_arg = f"--{arg_name.replace('_', '-')}"
-        if cli_arg not in tritonbench_args:
-            tritonbench_args.extend([cli_arg, str(arg_value)])
+    # Get module's TRITONBENCH_ARGS if any
+    module_args = getattr(module, "TRITONBENCH_ARGS", {})
+
+    # Add module args to tritonbench_args if not already present
+    for arg_name, arg_value in module_args.items():
+        arg_flag = f"--{arg_name.replace('_', '-')}"
+        if arg_flag not in tritonbench_args:
+            tritonbench_args.extend([arg_flag, str(arg_value)])
 
-    tb_args = tb_parser.parse_args(tritonbench_args)
+    # Parse known args and collect unknown ones for operator
+    tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
 
     # Register the Helion kernel with tritonbench BEFORE importing the operator
     from tritonbench.utils.triton_op import (  # pyright: ignore[reportMissingImports]
@@ -286,12 +283,12 @@ def _inner() -> Callable[..., Any]:
        file=sys.stderr,
    )
 
-    # Create and run the operator
-    op = Operator(tb_args=tb_args, extra_args={})
+    # Create and run the operator with unknown args
+    op = Operator(tb_args=tb_args, extra_args=unknown_args)
 
     # Run with proper parameters
-    warmup = getattr(tb_args, "warmup", 25)
-    rep = getattr(tb_args, "iter", 100)
+    warmup = int(getattr(tb_args, "warmup", 25))
+    rep = int(getattr(tb_args, "iter", 100))
     op.run(warmup=warmup, rep=rep)
 
     # Print results
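
To see the new flow in benchmarks/run.py in isolation, here is a minimal standalone sketch of the TRITONBENCH_ARGS merge followed by parse_known_args. The SimpleNamespace stand-in for an imported example module and the --some-operator-flag flag are illustrative only, not part of this commit:

import argparse
import types

# Stand-in for an imported kernel example module that declares defaults.
module = types.SimpleNamespace(TRITONBENCH_ARGS={"num_inputs": 3})

tritonbench_args = ["--op", "rms_norm", "--some-operator-flag", "1"]

# Merge module-level defaults unless the user already passed the flag.
for arg_name, arg_value in getattr(module, "TRITONBENCH_ARGS", {}).items():
    arg_flag = f"--{arg_name.replace('_', '-')}"
    if arg_flag not in tritonbench_args:
        tritonbench_args.extend([arg_flag, str(arg_value)])

# parse_known_args keeps unrecognized flags so they can later be forwarded
# to the operator via Operator(tb_args=tb_args, extra_args=unknown_args).
parser = argparse.ArgumentParser()
parser.add_argument("--op")
parser.add_argument("--num-inputs")
tb_args, unknown_args = parser.parse_known_args(tritonbench_args)

print(tb_args.num_inputs)  # "3" (merged module default)
print(unknown_args)        # ["--some-operator-flag", "1"]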

examples/jagged_mean.py
Lines changed: 6 additions & 0 deletions
@@ -5,6 +5,12 @@
 import helion
 from helion._testing import run_example
 import helion.language as hl
+from helion.utils import get_gpu_memory_info
+
+# TritonBench configuration - adjust based on available GPU memory
+if get_gpu_memory_info()[0] < 16.0:
+    # Low memory configuration
+    TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64}
 
 
 @helion.kernel()
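
Note that benchmarks/run.py reads this attribute with getattr(module, "TRITONBENCH_ARGS", {}), so leaving it undefined is safe: on GPUs with 16 GB or more the conditional never binds the name, the lookup falls back to an empty dict, and tritonbench's stock input sizes are used.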

examples/rms_norm.py
Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,10 @@
 from helion._testing import run_example
 import helion.language as hl
 
+# TritonBench configuration
+# TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
+TRITONBENCH_ARGS = {"num_inputs": 3}
+
 
 @helion.kernel(static_shapes=True)
 def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:

helion/utils.py (new file)
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import torch
+
+
+def get_gpu_memory_info(device_id: int | None = None) -> tuple[float, float]:
+    """
+    Get total and available GPU memory in GB.
+
+    Args:
+        device_id: GPU device ID. If None, uses current device.
+
+    Returns:
+        Tuple of (total_memory_gb, available_memory_gb)
+    """
+    if not torch.cuda.is_available():
+        return (0.0, 0.0)
+
+    if device_id is None:
+        device_id = torch.cuda.current_device()
+
+    # Get total memory
+    total_memory = torch.cuda.get_device_properties(device_id).total_memory
+
+    # Get reserved memory (memory allocated by the caching allocator)
+    reserved_memory = torch.cuda.memory_reserved(device_id)
+
+    # Available memory is approximately total - reserved
+    available_memory = total_memory - reserved_memory
+
+    # Convert to GB
+    total_gb = total_memory / (1024**3)
+    available_gb = available_memory / (1024**3)
+
+    return (total_gb, available_gb)
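
A quick usage sketch for the new helper; the 16 GB threshold mirrors the check in examples/jagged_mean.py, and the printed messages are illustrative:

from helion.utils import get_gpu_memory_info

total_gb, available_gb = get_gpu_memory_info()  # defaults to the current device
if total_gb == 0.0:
    print("CUDA is not available")
elif total_gb < 16.0:
    print(f"Low-memory GPU: {available_gb:.1f} GB free of {total_gb:.1f} GB")
else:
    print(f"GPU has {total_gb:.1f} GB total; using default benchmark sizes")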
