211 changes: 134 additions & 77 deletions benchmarks/run.py
@@ -26,7 +26,8 @@
from typing import Callable

# Maps tritonbench op names to Helion kernel examples
KERNEL_MAPPINGS: dict[str, tuple[str, str, str]] = {
# Can map to a single kernel or a list of kernel variants
KERNEL_MAPPINGS: dict[str, tuple[str, str, str] | tuple[str, list[tuple[str, str]]]] = {
# <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
"vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
"embedding": (
@@ -80,6 +81,14 @@
"examples.layer_norm",
"layer_norm_fwd",
),
# Multiple kernel variants:
"gemm": (
"tritonbench.operators.gemm.operator",
[
("examples.matmul", "matmul_tritonbench"),
("examples.matmul_split_k", "matmul_split_k_tritonbench"),
],
),
}
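Both entry shapes reduce to the same normalized form: run_kernel (later in this diff) turns a plain three-tuple into a one-element variant list and passes the list form through unchanged. As a sketch of the resulting values (illustration only, not code in this PR), the two kinds of entries above normalize to roughly:

    # Single-kernel entry ("vector_add") becomes a one-element variant list:
    tritonbench_module = "tritonbench.operators.vector_add.operator"
    variants = [("examples.add", "add")]

    # Multi-variant entry ("gemm") keeps its list of (module, function) pairs:
    tritonbench_module = "tritonbench.operators.gemm.operator"
    variants = [
        ("examples.matmul", "matmul_tritonbench"),
        ("examples.matmul_split_k", "matmul_split_k_tritonbench"),
    ]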


@@ -210,7 +219,7 @@ def run_kernel(
tritonbench_args: list[str],
input_shard_info: tuple[int, int] | None = None,
) -> None:
"""Run a single kernel benchmark."""
"""Run a kernel benchmark, handling both single and multiple variants."""
# Check if kernel is in the mapping table
if kernel_name not in KERNEL_MAPPINGS:
print(f"Error: Unknown kernel '{kernel_name}'", file=sys.stderr)
@@ -219,25 +228,33 @@ def run_kernel(
)
sys.exit(1)

tritonbench_module, module_path, func_name = KERNEL_MAPPINGS[kernel_name]
mapping = KERNEL_MAPPINGS[kernel_name]

# Normalize to list of variants format
if len(mapping) == 2 and isinstance(mapping[1], list):
# Multiple variants with shared tritonbench module
tritonbench_module = mapping[0]
variants = mapping[1]
else:
# Single kernel with full mapping - convert to list format
assert len(mapping) == 3 # Type narrowing for pyright
tritonbench_module, module_path, func_name = mapping
variants = [(module_path, func_name)]

# Run all variants in the same benchmark
run_kernel_variants(
kernel_name, tritonbench_module, variants, tritonbench_args, input_shard_info
)

# Import from the mapped module
try:
module = importlib.import_module(module_path)
if not hasattr(module, func_name):
print(
f"Error: Module '{module_path}' does not have a function named '{func_name}'",
file=sys.stderr,
)
sys.exit(1)
kernel_func = getattr(module, func_name)
except ImportError as e:
print(
f"Error: Could not import {func_name} from {module_path}", file=sys.stderr
)
print(f"Import error: {e}", file=sys.stderr)
sys.exit(1)
return

def run_kernel_variants(
kernel_name: str,
tritonbench_module: str,
variants: list[tuple[str, str]],
tritonbench_args: list[str],
input_shard_info: tuple[int, int] | None = None,
) -> None:
"""Run kernel variants in the same benchmark run."""

# Import tritonbench components
try:
@@ -260,19 +277,26 @@ def run_kernel(
assert "--op" not in tritonbench_args
tritonbench_args = ["--op", operator_name, *tritonbench_args]

# Get module's TRITONBENCH_ARGS if any
module_args = getattr(module, "TRITONBENCH_ARGS", {})
# Collect all module args from all variants
all_module_args = {}
for module_path, _ in variants:
try:
module = importlib.import_module(module_path)
module_args = getattr(module, "TRITONBENCH_ARGS", {})
all_module_args.update(module_args)
except ImportError:
pass

# Add module args to tritonbench_args if not already present
for arg_name, arg_value in module_args.items():
for arg_name, arg_value in all_module_args.items():
arg_flag = f"--{arg_name.replace('_', '-')}"
if arg_flag not in tritonbench_args:
tritonbench_args.extend([arg_flag, str(arg_value)])

# Parse known args and collect unknown ones for operator
tb_args, unknown_args = tb_parser.parse_known_args(tritonbench_args)

# Import and run the operator
# Import and get the operator class
try:
operator_module = importlib.import_module(tritonbench_module)
Operator = operator_module.Operator
@@ -285,64 +309,97 @@
print(f"Import error: {e}", file=sys.stderr)
sys.exit(1)

# Create the benchmark method
def helion_method(
self: object,
*args: object,
) -> Callable[..., object]:
"""Helion implementation."""

# Reset all Helion kernels before creating the benchmark function
# so that each input size can go through its own autotuning.
from helion.runtime.kernel import Kernel

for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, Kernel):
attr.reset()

def _inner() -> Callable[..., Any] | object:
# Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
# This ensures we run autotuning even if the kernel has pre-specified configs
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
# Find all Kernel objects in the module and force autotuning
for attr_name in dir(module):
attr = getattr(module, attr_name)
if isinstance(attr, Kernel):
attr.settings.force_autotune = True

result = kernel_func(*args)
if callable(result):
return result()
return result

return _inner

# Method name for the benchmark
helion_method_name = f"helion_{kernel_name}"

# Import register_benchmark API
from tritonbench.utils.triton_op import ( # pyright: ignore[reportMissingImports]
register_benchmark,
)

# Use register_benchmark decorator
decorated_method = register_benchmark(
operator_name=operator_name,
func_name=helion_method_name,
baseline=False,
enabled=True,
fwd_only=False,
label=helion_method_name,
)(helion_method)

# Set the decorated method on the Operator class
setattr(Operator, helion_method_name, decorated_method)

print(
f"Running {operator_name} benchmark with Helion implementation...\n",
file=sys.stderr,
)
# Register all variants as separate methods
for module_path, func_name in variants:
# Import the kernel function
try:
module = importlib.import_module(module_path)
if not hasattr(module, func_name):
print(
f"Error: Module '{module_path}' does not have a function named '{func_name}'",
file=sys.stderr,
)
continue
kernel_func = getattr(module, func_name)
except ImportError as e:
print(
f"Error: Could not import {func_name} from {module_path}",
file=sys.stderr,
)
print(f"Import error: {e}", file=sys.stderr)
continue

# Factory that builds the benchmark method so each variant captures its own module and kernel function
def create_helion_method(
mod: Any, # noqa: ANN401
kfunc: Callable[..., Any],
) -> Callable[..., Any]:
def helion_method(
self: object,
*args: object,
) -> Callable[..., object]:
"""Helion implementation."""

# Reset all Helion kernels before creating the benchmark function
# so that each input size can go through its own autotuning.
from helion.runtime.kernel import Kernel

for attr_name in dir(mod):
attr = getattr(mod, attr_name)
if isinstance(attr, Kernel):
attr.reset()

def _inner() -> Callable[..., Any] | object:
# Force autotuning unless HELION_USE_DEFAULT_CONFIG=1 is set
# This ensures we run autotuning even if the kernel has pre-specified configs
if os.environ.get("HELION_USE_DEFAULT_CONFIG", "0") != "1":
# Find all Kernel objects in the module and force autotuning
for attr_name in dir(mod):
attr = getattr(mod, attr_name)
if isinstance(attr, Kernel):
attr.settings.force_autotune = True

result = kfunc(*args)
if callable(result):
return result()
return result

return _inner

return helion_method

# Method name for the benchmark
variant_name = func_name
helion_method_name = f"helion_{variant_name}"

# Use register_benchmark decorator
decorated_method = register_benchmark(
operator_name=operator_name,
func_name=helion_method_name,
baseline=False,
enabled=True,
fwd_only=False,
label=helion_method_name,
)(create_helion_method(module, kernel_func))

# Set the decorated method on the Operator class
setattr(Operator, helion_method_name, decorated_method)

if len(variants) == 1:
print(
f"Running {operator_name} benchmark with Helion implementation...\n",
file=sys.stderr,
)
else:
print(
f"Running {operator_name} benchmark with {len(variants)} Helion implementations...\n",
file=sys.stderr,
)

# Create and run the operator with unknown args
op = Operator(tb_args=tb_args, extra_args=unknown_args)
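A note on create_helion_method above: the factory is needed because Python closures bind loop variables late. If helion_method were defined directly inside the for-loop over variants and referred to the loop variables, every registered method would end up running the last variant imported. With the gemm mapping, this loop registers two separate methods on the Operator class, helion_matmul_tritonbench and helion_matmul_split_k_tritonbench, each bound to its own kernel. A minimal standalone sketch of the pitfall and the fix (illustration only, not code from this PR):

    # Late binding: every closure sees the loop variable's final value.
    late_bound = [lambda: i for i in range(3)]
    print([f() for f in late_bound])  # [2, 2, 2]

    # Factory function, as in create_helion_method: each closure gets its own binding.
    def make(value):
        return lambda: value

    bound = [make(i) for i in range(3)]
    print([f() for f in bound])  # [0, 1, 2]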
9 changes: 9 additions & 0 deletions examples/matmul.py
@@ -81,6 +81,15 @@ def baseline_wrapper(x: Tensor, y: Tensor) -> Tensor:
)


def matmul_tritonbench(
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
) -> Callable:
"""Wrapper for tritonbench that matches its interface."""
if bias is not None:
return lambda: matmul(a, b, lambda acc, tile: acc + bias[tile[1]])
return lambda: matmul(a, b)


def main() -> None:
# autotune(1024, 1024, 1024)
check(1024, 1024, 1024)
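For context, the wrapper's signature suggests tritonbench's gemm operator supplies (a, b, bias) with bias possibly None, and benchmarks/run.py then invokes the returned zero-argument callable (the "if callable(result)" branch in _inner). A rough usage sketch under those assumptions; shapes, dtype, and device are illustrative only:

    import torch
    from examples.matmul import matmul_tritonbench

    a = torch.randn(256, 128, device="cuda", dtype=torch.float16)
    b = torch.randn(128, 64, device="cuda", dtype=torch.float16)
    bias = torch.randn(64, device="cuda", dtype=torch.float16)  # one value per output column

    fn = matmul_tritonbench(a, b, bias)  # returns a zero-argument callable
    out = fn()                           # runs the Helion matmul with the bias epilogue
    # out should have shape (256, 64); the epilogue adds bias[tile[1]] to each output tile.

matmul_split_k_tritonbench in the next file follows the same pattern, passing the bias epilogue through the epilogue keyword instead.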
11 changes: 11 additions & 0 deletions examples/matmul_split_k.py
@@ -58,6 +58,17 @@ def check(m: int, k: int, n: int) -> None:
run_example(kernel_with_bias, expected_with_bias, (x, y), atol=1)


def matmul_split_k_tritonbench(
a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor | None
) -> Callable:
"""Wrapper for tritonbench that matches its interface."""
if bias is not None:
return lambda: matmul_split_k(
a, b, epilogue=lambda acc, tile: acc + bias[tile[1]]
)
return lambda: matmul_split_k(a, b)


def main() -> None:
check(64, 32768, 64)
