Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions triton_viz/core/trace.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from copy import deepcopy
from triton.runtime import KernelInterface, Autotuner
from triton.runtime.autotuner import Heuristics
from triton.runtime.interpreter import InterpretedFunction
from triton import JITFunction

Expand Down Expand Up @@ -40,13 +41,13 @@ def add_client(self, new_client: Union[str, Client]) -> None:

def __init__(
self,
runner: Union[JITFunction, InterpretedFunction, Autotuner],
runner: Union[JITFunction, InterpretedFunction, Autotuner, Heuristics],
client: Union[str, Client],
) -> None:
self.fn = runner

def unpack_kernel(
source: Union["Trace", JITFunction, InterpretedFunction],
source: Union["Trace", JITFunction, InterpretedFunction, Heuristics],
) -> tuple[
Optional[JITFunction], Optional[Callable], Optional[InterpretedFunction]
]:
Expand All @@ -57,6 +58,9 @@ def unpack_kernel(
return source, base_fn, InterpretedFunction(base_fn)
if isinstance(source, InterpretedFunction):
return None, source.fn, source
if isinstance(source, Heuristics):
# Heuristics wraps another kernel, recursively unpack it
return unpack_kernel(source.fn)
raise TypeError(f"Unsupported runner type: {type(source)}")

if isinstance(runner, Autotuner):
Expand All @@ -70,6 +74,15 @@ def unpack_kernel(
warmup_runner.fn = self.jit_fn
self.runner = runner
self.warmup_runner = warmup_runner
elif isinstance(runner, Heuristics):
self.jit_fn, self.base_fn, self.interpreted_fn = unpack_kernel(runner.fn)
# replace the fn with an InterpretedFunction to avoid re-jitting
runner.fn = self.interpreted_fn
# make a deepcopy of the runner for warmup
warmup_runner = deepcopy(runner)
warmup_runner.fn = self.jit_fn
self.runner = runner
self.warmup_runner = warmup_runner
else:
self.jit_fn, self.base_fn, self.interpreted_fn = unpack_kernel(runner)
self.runner = self.interpreted_fn
Expand Down Expand Up @@ -124,7 +137,9 @@ def decorator(kernel) -> Trace:
return kernel

# First-time wrapping
if isinstance(kernel, (JITFunction, InterpretedFunction, Autotuner)):
if isinstance(
kernel, (JITFunction, InterpretedFunction, Autotuner, Heuristics)
):
return Trace(kernel, clients)

# If the object is already a Trace, just append the new client(s)
Expand Down