
Commit 2f00646

[Benchmark] Add jagged_mean tritonbench integration
stack-info: PR: #264, branch: yf225/stack/10
1 parent b3e504a commit 2f00646

File tree

4 files changed: +65 -25 lines changed
benchmarks/run.py

Lines changed: 20 additions & 23 deletions
@@ -20,14 +20,14 @@
 from typing import Callable
 
 # Maps tritonbench op names to Helion kernel examples
-KERNEL_MAPPINGS: dict[str, tuple[str, str] | tuple[str, str, dict[str, Any]]] = {
-    # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>, <optional_extra_args>)
+KERNEL_MAPPINGS: dict[str, tuple[str, str]] = {
+    # <tritonbench_op_name>: (<helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("examples.add", "add"),
     "embedding": ("examples.embedding", "embedding_tritonbench"),
     "vector_exp": ("examples.exp", "exp_tritonbench"),
-    # TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
-    "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench", {"num_inputs": 3}),
+    "rms_norm": ("examples.rms_norm", "rms_norm_tritonbench"),
     "sum": ("examples.sum", "sum_tritonbench"),
+    "jagged_mean": ("examples.jagged_mean", "jagged_mean_tritonbench"),
 }
 
@@ -168,14 +168,8 @@ def main() -> None:
 
     # Check if kernel is in the mapping table
    assert kernel_name in KERNEL_MAPPINGS
-    mapping = KERNEL_MAPPINGS[kernel_name]
-
-    # Parse mapping - can be (module, func) or (module, func, extra_args)
-    if len(mapping) == 2:
-        module_path, func_name = mapping
-        kernel_extra_args = {}
-    else:
-        module_path, func_name, kernel_extra_args = mapping
+    module_path, func_name = KERNEL_MAPPINGS[kernel_name]
+
     # Import from the mapped module
     try:
         module = importlib.import_module(module_path)
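
For orientation, this is how a simplified two-tuple mapping entry resolves to a callable at runtime. A minimal sketch (the importlib call mirrors the hunk above; the getattr step and the names used are illustrative assumptions):

import importlib

KERNEL_MAPPINGS = {"jagged_mean": ("examples.jagged_mean", "jagged_mean_tritonbench")}

module_path, func_name = KERNEL_MAPPINGS["jagged_mean"]
module = importlib.import_module(module_path)  # imports examples.jagged_mean
kernel_func = getattr(module, func_name)       # the tritonbench wrapper function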
@@ -213,14 +207,17 @@ def main() -> None:
     assert "--op" not in tritonbench_args
     tritonbench_args = ["--op", operator_name, *tritonbench_args]
 
-    # Apply kernel-specific default arguments if not already specified by user
-    for arg_name, arg_value in kernel_extra_args.items():
-        # Convert underscore to hyphen for CLI args (e.g., num_inputs -> --num-inputs)
-        cli_arg = f"--{arg_name.replace('_', '-')}"
-        if cli_arg not in tritonbench_args:
-            tritonbench_args.extend([cli_arg, str(arg_value)])
+    # Get module's TRITONBENCH_ARGS if any
+    module_args = getattr(module, "TRITONBENCH_ARGS", {})
+
+    # Add module args to tritonbench_args if not already present
+    for arg_name, arg_value in module_args.items():
+        arg_flag = f"--{arg_name.replace('_', '-')}"
+        if arg_flag not in tritonbench_args:
+            tritonbench_args.extend([arg_flag, str(arg_value)])
 
-    tb_args = tb_parser.parse_args(tritonbench_args)
+    # Parse known args and collect unknown ones for operator
+    tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)
 
     # Register the Helion kernel with tritonbench BEFORE importing the operator
     from tritonbench.utils.triton_op import ( # pyre-ignore[21]
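
A worked illustration of the merge logic above, as a standalone sketch (the TRITONBENCH_ARGS value is the one added to examples/rms_norm.py in this commit; everything else is illustrative):

TRITONBENCH_ARGS = {"num_inputs": 3}  # module-level default from examples/rms_norm.py

tritonbench_args = ["--op", "rms_norm"]
for arg_name, arg_value in TRITONBENCH_ARGS.items():
    arg_flag = f"--{arg_name.replace('_', '-')}"  # num_inputs -> --num-inputs
    if arg_flag not in tritonbench_args:          # a user-supplied flag wins
        tritonbench_args.extend([arg_flag, str(arg_value)])

print(tritonbench_args)  # ['--op', 'rms_norm', '--num-inputs', '3']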
@@ -284,12 +281,12 @@ def _inner() -> Callable[..., Any]: # pyre-ignore[3]
             file=sys.stderr,
         )
 
-    # Create and run the operator
-    op = Operator(tb_args=tb_args, extra_args={})
+    # Create and run the operator with unknown args
+    op = Operator(tb_args=tb_args, extra_args=unknown_args)
 
     # Run with proper parameters
-    warmup = getattr(tb_args, "warmup", 25)
-    rep = getattr(tb_args, "iter", 100)
+    warmup = int(getattr(tb_args, "warmup", 25))
+    rep = int(getattr(tb_args, "iter", 100))
     op.run(warmup=warmup, rep=rep)
 
     # Print results
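
The switch from parse_args to parse_known_args is what lets operator-specific flags (such as jagged_mean's B/M/seqlen settings) pass through without tripping tritonbench's parser; they are collected and forwarded as extra_args. A minimal sketch of that behavior (the registered flags here are assumptions for illustration):

import argparse

tb_parser = argparse.ArgumentParser()
tb_parser.add_argument("--op")          # known to the parser
tb_parser.add_argument("--num-inputs")  # known to the parser

tb_args, unknown_args = tb_parser.parse_known_args(
    ["--op", "jagged_mean", "--B", "32", "--seqlen", "64"]
)
print(unknown_args)  # ['--B', '32', '--seqlen', '64'] -> forwarded via extra_args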

examples/jagged_mean.py

Lines changed: 6 additions & 2 deletions
@@ -5,6 +5,12 @@
 import helion
 from helion._testing import run_example
 import helion.language as hl
+from helion.utils import get_gpu_memory_info
+
+# TritonBench configuration - adjust based on available GPU memory
+if get_gpu_memory_info()[0] < 16.0:
+    # Low memory configuration
+    TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64}
 
 
 @helion.kernel()
@@ -123,8 +129,6 @@ def jagged_mean_tritonbench(
     Returns:
         Tensor of shape (B, M) with mean values per row and feature
     """
-    assert isinstance(x, NestedTensor), f"Input x must be a NestedTensor, got {type(x)}"
-
     x_values = x._values
     x_offsets = x._offsets  # pyre-ignore[16]
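
For reference, the _values/_offsets pair the function now reads unconditionally is how a jagged-layout NestedTensor stores its data. A small sketch (these are private attributes, so this is illustrative only and assumes a PyTorch version that supports the jagged layout):

import torch

# Three rows of lengths 2, 1 and 3, each with M = 4 features
rows = [torch.randn(2, 4), torch.randn(1, 4), torch.randn(3, 4)]
x = torch.nested.nested_tensor(rows, layout=torch.jagged)

print(x._values.shape)  # torch.Size([6, 4]) -- all rows concatenated
print(x._offsets)       # tensor([0, 2, 3, 6]) -- row boundaries into _values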

examples/rms_norm.py

Lines changed: 4 additions & 0 deletions
@@ -6,6 +6,10 @@
 from helion._testing import run_example
 import helion.language as hl
 
+# TritonBench configuration
+# TODO(yf225): reduction dim size = 8192 currently throws error. After it's fixed we can remove "num_inputs" extra arg.
+TRITONBENCH_ARGS = {"num_inputs": 3}
+
 
 @helion.kernel(static_shapes=True)
 def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:

helion/utils.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import torch
+
+
+def get_gpu_memory_info(device_id: int | None = None) -> tuple[float, float]:
+    """
+    Get total and available GPU memory in GB.
+
+    Args:
+        device_id: GPU device ID. If None, uses current device.
+
+    Returns:
+        Tuple of (total_memory_gb, available_memory_gb)
+    """
+    if not torch.cuda.is_available():
+        return (0.0, 0.0)
+
+    if device_id is None:
+        device_id = torch.cuda.current_device()
+
+    # Get total memory
+    total_memory = torch.cuda.get_device_properties(device_id).total_memory
+
+    # Get reserved memory (memory allocated by the caching allocator)
+    reserved_memory = torch.cuda.memory_reserved(device_id)
+
+    # Available memory is approximately total - reserved
+    available_memory = total_memory - reserved_memory
+
+    # Convert to GB
+    total_gb = total_memory / (1024**3)
+    available_gb = available_memory / (1024**3)
+
+    return (total_gb, available_gb)
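
Usage sketch for the new helper, mirroring the gate added to examples/jagged_mean.py (index [0] of the returned tuple is total memory; a CPU-only machine gets (0.0, 0.0) and therefore also takes the low-memory branch):

from helion.utils import get_gpu_memory_info

total_gb, available_gb = get_gpu_memory_info()
print(f"GPU memory: {available_gb:.1f} GB available of {total_gb:.1f} GB total")

# As in examples/jagged_mean.py: shrink benchmark inputs on GPUs under 16 GB
if total_gb < 16.0:
    TRITONBENCH_ARGS = {"B": 32, "M": 8, "seqlen": 64}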
