
Commit 65995f5

Allow passing tritonbench operator instance into kernel benchmark wrapper; Always return lambda for timing measurement (#596)
1 parent 1ac5365 commit 65995f5
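
In short: every *_tritonbench wrapper gains the tritonbench operator instance as its leading parameter, and every wrapper now defers its computation behind a lambda so the harness times only the kernel call. A schematic before/after (illustrative op/op_tritonbench names, not a specific file in this commit):

from typing import Callable

import torch

def op(x: torch.Tensor) -> torch.Tensor:  # stand-in for any Helion kernel
    return x * 2

# Before: the wrapper ran eagerly and returned the tensor.
def op_tritonbench_old(x: torch.Tensor) -> torch.Tensor:
    return op(x)

# After: the wrapper accepts the operator instance first and returns a
# zero-argument callable; execution is deferred until measurement.
def op_tritonbench(tb_op: object, x: torch.Tensor) -> Callable[[], torch.Tensor]:
    return lambda: op(x)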

File tree

12 files changed: +92 / -54 lines

benchmarks/run.py

Lines changed: 12 additions & 17 deletions
@@ -393,16 +393,9 @@ def run_kernel_variants(
     """Run kernel variants in the same benchmark run."""

     # Import tritonbench components
-    try:
-        from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]
-            get_parser,
-        )
-    except ImportError:
-        print(
-            "Error: Could not import tritonbench. Make sure it's in the path.",
-            file=sys.stderr,
-        )
-        sys.exit(1)
+    from tritonbench.utils.parser import (  # pyright: ignore[reportMissingImports]
+        get_parser,
+    )

     # Get the tritonbench operator name
     operator_name = kernel_name

@@ -500,14 +493,16 @@ def helion_method(
         attr.settings.force_autotune = True
         attr.settings.static_shape = True  # pyright: ignore[reportAttributeAccessIssue]

-        def _inner() -> Callable[..., Any] | object:
-            # BENCHMARK HOT PATH, do not add any new logic here
-            result = kfunc(*args, **kwargs)
-            if callable(result):
-                return result()
-            return result
+        if isinstance(kfunc, Kernel):
+            # Helion kernel - we call it in a lambda to delay execution until measurement
+            measured_func_callable = lambda: kfunc(*args, **kwargs)  # noqa: E731
+        else:
+            # tritonbench integration wrapper - pass tritonbench operator instance as first argument
+            # The wrapper must return a callable that does the actual computation, for delayed execution
+            measured_func_callable = kfunc(self, *args, **kwargs)

-        return _inner
+        assert callable(measured_func_callable)
+        return measured_func_callable

     return helion_method
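
The net effect in helion_method is a uniform contract: the harness always receives a zero-argument callable, so wrapper dispatch and argument setup happen once, outside the timed region. A minimal sketch of how a timing loop consumes such a callable (the measure helper below is a hypothetical stand-in, not tritonbench's actual API):

import time
from typing import Any, Callable

def measure(fn: Callable[[], Any], warmup: int = 3, reps: int = 100) -> float:
    # fn is the zero-argument callable returned by helion_method: all setup
    # has already happened, so only the computation lands in the timed loop.
    for _ in range(warmup):
        fn()  # warmup also triggers compilation/autotuning before timing
    start = time.perf_counter()
    for _ in range(reps):
        fn()
    # a real GPU harness would also synchronize the device around the timed region
    return (time.perf_counter() - start) / reps

Both branches satisfy this contract: a Helion Kernel is deferred as lambda: kfunc(*args, **kwargs), while a *_tritonbench wrapper is called once with the operator instance and must itself return the deferred callable.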

examples/embedding.py

Lines changed: 7 additions & 4 deletions
@@ -10,6 +10,8 @@
 # -------
 from __future__ import annotations

+from typing import Callable
+
 import torch

 import helion

@@ -49,21 +51,22 @@ def embedding(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
 # Benchmark Wrapper
 # --------------
 def embedding_tritonbench(
-    V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor
-) -> torch.Tensor:
+    tb_op: object, V: int, D: int, inp: torch.Tensor, shared_weight: torch.Tensor
+) -> Callable[[], torch.Tensor]:
     """
     Wrapper for tritonbench that matches its interface.

     Args:
+        tb_op: TritonBench operator instance
         V: Vocabulary size (unused, provided for compatibility)
         D: Embedding dimension (unused, provided for compatibility)
         inp: Input tensor of indices
         shared_weight: Embedding weight matrix

     Returns:
-        Output tensor containing the embedding vectors
+        Callable that returns output tensor containing the embedding vectors
     """
-    return embedding(inp, shared_weight)
+    return lambda: embedding(inp, shared_weight)


 # %%
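
With the new signature, callers pass the operator instance first and invoke the returned lambda to actually run the kernel. A standalone sketch (assuming embedding_tritonbench is in scope, e.g. inside examples/embedding.py, and a CUDA device is available; since the wrapper body ignores tb_op, None suffices here and the shapes are illustrative):

import torch

V, D = 1000, 64
inp = torch.randint(0, V, (8, 128), device="cuda")
weight = torch.randn(V, D, device="cuda")

fn = embedding_tritonbench(None, V, D, inp, weight)  # nothing executes yet
out = fn()  # the lookup runs only now, matching the deferred-timing contract
# expected: out.shape == (8, 128, D) for a standard embedding lookup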

examples/exp.py

Lines changed: 8 additions & 3 deletions
@@ -10,6 +10,8 @@
 # -------
 from __future__ import annotations

+from typing import Callable
+
 import torch

 import helion

@@ -40,17 +42,20 @@ def exp(x: torch.Tensor) -> torch.Tensor:
 # %%
 # Benchmark Wrapper
 # --------------
-def exp_tritonbench(x: torch.Tensor) -> dict[str, torch.Tensor]:
+def exp_tritonbench(
+    tb_op: object, x: torch.Tensor
+) -> Callable[[], dict[str, torch.Tensor]]:
     """
     Wrapper for tritonbench that returns output in expected format.

     Args:
+        tb_op: TritonBench operator instance
         x: Input tensor

     Returns:
-        Dictionary containing the output tensor
+        Callable that returns dictionary containing the output tensor
     """
-    return {"output": exp(x)}
+    return lambda: {"output": exp(x)}


 # %%

examples/fp8_attention.py

Lines changed: 3 additions & 2 deletions
@@ -135,12 +135,13 @@ def preprocess_fp8_attention_inputs(

 # %%
 def fp8_attention_tritonbench(
-    q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
+    tb_op: object, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor
 ) -> Callable[[], torch.Tensor]:
     """
     Creates a callable function for benchmarking FP8 attention with tritonbench.
     Preprocesses inputs and returns a lambda function that calls the FP8 attention kernel.
     Args:
+        tb_op: TritonBench operator instance
         q: Query tensor of shape [batch, heads, seq_len, head_dim]
         k: Key tensor of shape [batch, heads, seq_len, head_dim]
         v: Value tensor of shape [batch, heads, seq_len, head_dim]

@@ -272,7 +273,7 @@ def check(batch: int, heads: int, seq_len: int, head_dim: int) -> None:
     v = torch.randn(batch, heads, seq_len, head_dim, dtype=torch.float16, device="cuda")
     from helion._testing import run_example

-    helion_fn = fp8_attention_tritonbench(q, k, v)
+    helion_fn = fp8_attention_tritonbench(None, q, k, v)
     pytorch_fn = fp8_attention_pytorch(q, k, v)
     run_example(
         helion_fn,

examples/fp8_gemm.py

Lines changed: 13 additions & 3 deletions
@@ -11,6 +11,7 @@
 from __future__ import annotations

 import os
+from typing import Callable

 import torch

@@ -79,16 +80,25 @@ def reference_fp8_gemm_pytorch(


 # %%
-def fp8_gemm_tritonbench(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+def fp8_gemm_tritonbench(
+    tb_op: object,
+    a: torch.Tensor,
+    b: torch.Tensor,
+    scale_a: torch.Tensor,
+    scale_b: torch.Tensor,
+) -> Callable[[], torch.Tensor]:
     """
     Wrapper for TritonBench compatibility.
     Args:
+        tb_op: TritonBench operator instance
         a (torch.Tensor): Left input tensor in FP8 format.
         b (torch.Tensor): Right input tensor in FP8 format.
+        scale_a (torch.Tensor): Scale factor for tensor a (unused in our implementation).
+        scale_b (torch.Tensor): Scale factor for tensor b (unused in our implementation).
     Returns:
-        torch.Tensor: Output tensor in FP16 format.
+        Callable that returns output tensor in FP16 format.
     """
-    return fp8_gemm(a, b)
+    return lambda: fp8_gemm(a, b)


 # %%

examples/jagged_hstu_attn.py

Lines changed: 6 additions & 3 deletions
@@ -10,6 +10,8 @@
 # -------
 from __future__ import annotations

+from typing import Callable
+
 import torch

 import helion

@@ -143,15 +145,16 @@ def _helion_jagged_attention_kernel(
 # Benchmark Wrapper
 # --------------
 def ragged_attention_tritonbench(
+    tb_op: object,
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     seq_offsets: torch.Tensor,
     num_targets: torch.Tensor | None,
     max_seq_len: int,
-) -> torch.Tensor:
+) -> Callable[[], torch.Tensor]:
     """Wrapper function for jagged attention kernel"""
-    return _helion_jagged_attention_kernel(
+    return lambda: _helion_jagged_attention_kernel(
         max_seq_len=max_seq_len,
         alpha=1.0 / v.size(2) ** 2,
         q=q,

@@ -246,7 +249,7 @@ def _triton_hstu_mha(
     baselines["tritonbench"] = _triton_hstu_mha

     run_example(
-        ragged_attention_tritonbench,
+        lambda *args: ragged_attention_tritonbench(None, *args)(),
         baselines,
         (q, k, v, seq_offsets, None, max_seq_len),
     )
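
The run_example call above adapts the wrapper to an eager interface by binding None for the operator instance and invoking the returned lambda immediately. The same adaptation works for any of these wrappers; a generic sketch (the eager helper is illustrative, not part of the codebase):

from typing import Any, Callable

def eager(wrapper: Callable[..., Callable[[], Any]]) -> Callable[..., Any]:
    # Turn a deferred tritonbench wrapper back into an eager function,
    # e.g. for accuracy checks that compare outputs directly.
    return lambda *args, **kwargs: wrapper(None, *args, **kwargs)()

# e.g.: run_example(eager(ragged_attention_tritonbench), baselines, (q, k, v, seq_offsets, None, max_seq_len))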

examples/jagged_mean.py

Lines changed: 7 additions & 4 deletions
@@ -11,6 +11,8 @@
 # -------
 from __future__ import annotations

+from typing import Callable
+
 import torch

 import helion

@@ -136,20 +138,21 @@ def reference_jagged_mean_kernel_pytorch(
 # Benchmark Wrapper
 # --------------
 def jagged_mean_tritonbench(
-    x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
-) -> torch.Tensor:
+    tb_op: object, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+) -> Callable[[], torch.Tensor]:
     """
     Wrapper for tritonbench that matches the expected interface.

     Args:
+        tb_op: TritonBench operator instance
         x: Nested tensor in jagged format with shape (B, *, M)
         B: Batch size
         M: Number of features
         seqlen: Maximum sequence length
         sparsity: Sparsity factor (not used)

     Returns:
-        Tensor of shape (B, M) with mean values per row and feature
+        Callable that returns tensor of shape (B, M) with mean values per row and feature
     """
     x_values = x._values
     x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]

@@ -160,7 +163,7 @@ def jagged_mean_tritonbench(
         dtype=torch.int32,
         device=x_values.device,  # pyright: ignore[reportAttributeAccessIssue]
     )
-    return jagged_mean_kernel(x_values, x_offsets, feature_counts, M)
+    return lambda: jagged_mean_kernel(x_values, x_offsets, feature_counts, M)


 # %%

examples/jagged_softmax.py

Lines changed: 6 additions & 4 deletions
@@ -11,6 +11,7 @@
 from __future__ import annotations

 import itertools
+from typing import Callable

 import torch

@@ -135,22 +136,23 @@ def jagged_softmax_kernel(
 # Benchmark Wrapper
 # --------------
 def jagged_softmax_tritonbench(
-    x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
-) -> torch.Tensor:
+    tb_op: object, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
+) -> Callable[[], torch.Tensor]:
     """
     Wrapper for tritonbench that matches the expected interface.

     Args:
+        tb_op: TritonBench operator instance
         x: Nested tensor in jagged format with shape (B, *, M)
         B: Batch size (unused)
         M: Number of features (unused)
         seqlen: Maximum sequence length (unused)
         sparsity: Sparsity factor (unused)

     Returns:
-        Tensor of shape (N, M), where N = total number of rows in the jagged tensor
+        Callable that returns tensor of shape (N, M), where N = total number of rows in the jagged tensor
     """
-    return jagged_softmax_kernel(x._values, x._offsets)  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]
+    return lambda: jagged_softmax_kernel(x._values, x._offsets)  # pyright: ignore[reportArgumentType, reportAttributeAccessIssue]


 # %%

examples/matmul.py

Lines changed: 5 additions & 2 deletions
@@ -132,11 +132,12 @@ def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:

 # %%
 def matmul_tritonbench(
-    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
+    tb_op: object, a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
 ) -> Callable:
     """
     Wrapper for tritonbench that matches its interface.

     Args:
+        tb_op: TritonBench operator instance
         a (torch.Tensor): Left matrix.
         b (torch.Tensor): Right matrix.
         bias (torch.Tensor or None): Optional bias to add in the epilogue.

@@ -148,7 +149,9 @@ def matmul_tritonbench(
     return lambda: matmul(a, b)


-def addmm_tritonbench(bias: Tensor, mat1: Tensor, mat2: Tensor) -> Callable:
+def addmm_tritonbench(
+    tb_op: object, bias: Tensor, mat1: Tensor, mat2: Tensor
+) -> Callable:
     """
     Wrapper for tritonbench that performs a matrix multiplication of the matrices
     `mat1` and `mat2` followed by adding `bias` to the result.

examples/matmul_split_k.py

Lines changed: 2 additions & 1 deletion
@@ -97,11 +97,12 @@ def check(m: int, k: int, n: int) -> None:

 # %%
 def matmul_split_k_tritonbench(
-    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
+    tb_op: object, a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
 ) -> Callable:
     """
     Wrapper for tritonbench that matches its interface.

     Args:
+        tb_op: TritonBench operator instance
         a (torch.Tensor): Left input matrix.
         b (torch.Tensor): Right input matrix.
         bias (torch.Tensor or None): Optional bias to add in the epilogue.
