13 changes: 13 additions & 0 deletions benchmarks/run.py
@@ -253,6 +253,11 @@ class RunResult:
"num_inputs": 10, # int4_gemm takes long time on Benchmark CI, so use fewer inputs instead.
},
),
"jagged_sum": (
"tritonbench.operators.jagged_sum.operator",
"examples.jagged_sum",
"jagged_sum_tritonbench",
),
}


@@ -410,6 +415,14 @@ class RunResult:
"helion_grouped_gemm_jagged_persistent_tritonbench-speedup": "helion_speedup",
"helion_grouped_gemm_jagged_persistent_tritonbench-accuracy": "helion_accuracy",
},
"jagged_sum": {
"triton_jagged_sum_no_pad_simple_fused-speedup": "triton_speedup",
"triton_jagged_sum_no_pad_simple_fused-accuracy": "triton_accuracy",
"torch_compile_nested_tensor_integration-speedup": "torch_compile_speedup",
"torch_compile_nested_tensor_integration-accuracy": "torch_compile_accuracy",
"helion_jagged_sum_tritonbench-speedup": "helion_speedup",
"helion_jagged_sum_tritonbench-accuracy": "helion_accuracy",
},
"addmm": {
"aten_addmm": "baseline",
"triton_addmm-speedup": "triton_speedup",
197 changes: 197 additions & 0 deletions examples/jagged_sum.py
@@ -0,0 +1,197 @@
"""
Jagged Mean Example
===============

This example demonstrates how to compute the mean of each row in a jagged tensor
with variable features per row using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Callable

import torch

import helion
from helion._testing import run_example
import helion.language as hl


# %%
# Jagged Sum Kernel
# -----------------
@helion.kernel()
def jagged_sum_kernel(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
) -> torch.Tensor:
"""
    Compute the sum of each row in a jagged tensor with a variable number of elements per row.

Args:
x_data: 2-D tensor of shape (total_elements, M) holding all elements
x_offsets: (num_rows + 1) tensor. Row i is the slice
x_data[x_offsets[i] : x_offsets[i+1], :]

Returns:
        2-D tensor of shape (num_rows, M) containing the sum over the jagged dimension.
"""
M = x_data.shape[1]
num_rows = x_offsets.size(0) - 1

out = torch.zeros([num_rows, M], dtype=x_data.dtype, device=x_data.device)

# Flatten x_data for easier indexing
x_flat = x_data.view(-1)

# Process rows in tiles
for tile_b in hl.tile(num_rows):
starts = x_offsets[tile_b]
ends = x_offsets[tile_b.index + 1]
nnz = ends - starts
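        # The longest row in this tile bounds the inner reduction loop;
        # shorter rows are masked off via extra_mask further below.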
max_nnz = nnz.amax()

# Process features in tiles
for tile_m in hl.tile(M):
# Initialize accumulator
row_sums = hl.zeros([tile_b, tile_m], dtype=x_data.dtype)

# Process elements within each row
for tile_k in hl.tile(0, max_nnz):
# Compute flattened indices
base_indices = starts[:, None] + tile_k.index[None, :]
flat_indices = (
base_indices[:, :, None] * M + tile_m.index[None, None, :]
)

                # Mask positions past each row's end (feature bounds are handled by the tile)
row_mask = tile_k.index[None, :] < nnz[:, None]
combined_mask = row_mask[:, :, None]

x_slice = hl.load(
x_flat,
[flat_indices],
extra_mask=combined_mask,
)
# Accumulate - sum across the k dimension (dim=1)
row_sums = row_sums + x_slice.sum(dim=1)

            # Store the accumulated sums for this tile of rows and features
out[tile_b, tile_m] = row_sums

return out


# %%
# Reference Implementation
# --------------------
def reference_jagged_sum_kernel_pytorch(
x_data: torch.Tensor,
x_offsets: torch.Tensor,
) -> torch.Tensor:
"""
    PyTorch reference implementation for jagged sum with variable-length rows.

Args:
x_data: 2-D tensor holding all elements
x_offsets: Offsets tensor for row indexing

Returns:
        Tensor of shape (num_rows, M) containing the sum of each row
"""
num_rows = x_offsets.numel() - 1
M = x_data.size(1)
out = torch.zeros((num_rows, M), dtype=x_data.dtype, device=x_data.device)
for i in range(num_rows):
start = int(x_offsets[i])
end = int(x_offsets[i + 1])
if end > start:
out[i, :] = x_data[start:end, :].sum(dim=0)
return out
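
# A loop-free equivalent of the reference (a sketch under the same definitions;
# `row_ids` is a hypothetical helper name, not part of this example):
#
#   lengths = x_offsets[1:] - x_offsets[:-1]
#   row_ids = torch.repeat_interleave(
#       torch.arange(num_rows, device=x_data.device), lengths
#   )
#   out = torch.zeros(num_rows, M, dtype=x_data.dtype,
#                     device=x_data.device).index_add_(0, row_ids, x_data)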


# %%
# Benchmark Wrapper
# --------------
def jagged_sum_tritonbench(
tb_op: object, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
) -> Callable[[], torch.Tensor]:
"""
Wrapper for tritonbench that matches the expected interface.

Args:
tb_op: TritonBench operator instance
x: Nested tensor in jagged format with shape (B, *, M)
B: Batch size
M: Number of features
seqlen: Maximum sequence length
sparsity: Sparsity factor (not used)

Returns:
        Callable that returns a tensor of shape (B, M) with per-row sums for each feature
"""
x_values = x._values
x_offsets = x._offsets # pyright: ignore[reportAttributeAccessIssue]

return lambda: jagged_sum_kernel(x_values, x_offsets)
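
# For reference, a jagged input like `x` can be built with the public nested
# tensor API (a sketch; tritonbench constructs its own inputs):
#
#   x = torch.nested.nested_tensor(
#       [torch.randn(n, M, device="cuda") for n in seq_lengths],
#       layout=torch.jagged,
#   )
#
# `x._values` and `x._offsets` are internal attributes of the jagged layout,
# which is why the wrapper above reads them directly.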


# %%
# Helper function to create test data
# ---------------------------------
def create_test_jagged_tensor(
B: int,
M: int,
max_seqlen: int,
device: str = "cuda",
dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Create test jagged tensor data."""

# Generate random sequence lengths
seq_lengths = torch.randint(1, max_seqlen + 1, (B,), device=device)

# Create offsets
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=device),
torch.cumsum(seq_lengths, dim=0),
]
)

# Create values
nnz = int(x_offsets[-1])
x_data = torch.randn(nnz, M, dtype=dtype, device=device)

return x_data, x_offsets
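
# Example usage (assumes a CUDA device is available):
#
#   x_data, x_offsets = create_test_jagged_tensor(B=8, M=128, max_seqlen=64)
#   out = jagged_sum_kernel(x_data, x_offsets)  # shape (8, 128)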


# %%
# Main Function
# -----------
def main() -> None:
"""
    Main entry point that runs the jagged sum kernel verification.

    Creates a random jagged tensor, then compares the kernel implementation
    against the PyTorch reference implementation.
"""
B, M, max_seqlen = 8, 128, 64
device = "cuda"

x_data, x_offsets = create_test_jagged_tensor(
B, M, max_seqlen, device, dtype=torch.float32
)

run_example(
lambda x, o: jagged_sum_kernel(x, o),
lambda x, o: reference_jagged_sum_kernel_pytorch(x, o),
(x_data, x_offsets),
)


if __name__ == "__main__":
main()
81 changes: 81 additions & 0 deletions test/test_examples.expected
@@ -1758,6 +1758,87 @@ def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
_launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
return out.reshape(N, M)

--- assertExpectedJournal(TestExamples.test_jagged_sum)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion_jagged_sum_kernel(x_offsets, x_flat, out, out_stride_0, out_stride_1, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
pid_0 = tl.program_id(0)
offset_0 = pid_0 * _BLOCK_SIZE_0
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
mask_0 = indices_0 < num_rows
starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
v_0 = tl.full([], 1, tl.int32)
v_1 = indices_0 + v_0
ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
v_2 = ends - starts
_mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64))
max_nnz = tl.cast(tl.max(_mask_to, 0), tl.int64)
for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
mask_1 = indices_1 < M
max_nnz_copy = max_nnz
starts_copy = starts
v_2_copy = v_2
max_nnz_copy_0 = max_nnz_copy
starts_copy_0 = starts_copy
v_2_copy_0 = v_2_copy
row_sums = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
for offset_2 in tl.range(0, max_nnz_copy_0.to(tl.int32), _BLOCK_SIZE_2):
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
mask_2 = indices_2 < max_nnz_copy_0
starts_copy_0_copy = starts_copy_0
v_2_copy_0_copy = v_2_copy_0
row_sums_copy = row_sums
starts_copy_0_copy_0 = starts_copy_0_copy
v_2_copy_0_copy_0 = v_2_copy_0_copy
row_sums_copy_0 = row_sums_copy
subscript = starts_copy_0_copy_0[:, None]
subscript_1 = indices_2[None, :]
v_3 = tl.cast(subscript_1, tl.int64)
v_4 = subscript + v_3
subscript_2 = v_4[:, :, None]
v_5 = subscript_2 * M
subscript_3 = indices_1[None, None, :]
v_6 = tl.cast(subscript_3, tl.int64)
v_7 = v_5 + v_6
subscript_4 = indices_2[None, :]
subscript_5 = v_2_copy_0_copy_0[:, None]
v_8 = tl.cast(subscript_4, tl.int64)
v_9 = v_8 < subscript_5
combined_mask = v_9[:, :, None]
x_slice = tl.load(x_flat + v_7 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & combined_mask, other=0)
sum_1 = tl.cast(tl.sum(x_slice, 1), tl.float32)
row_sums = row_sums_copy_0 + sum_1
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), row_sums, mask_0[:, None] & mask_1[None, :])

def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
"""
    Compute the sum of each row in a jagged tensor with a variable number of elements per row.

Args:
x_data: 2-D tensor of shape (total_elements, M) holding all elements
x_offsets: (num_rows + 1) tensor. Row i is the slice
x_data[x_offsets[i] : x_offsets[i+1], :]

Returns:
        2-D tensor of shape (num_rows, M) containing the sum over the jagged dimension.
"""
M = x_data.shape[1]
num_rows = x_offsets.size(0) - 1
out = torch.zeros([num_rows, M], dtype=x_data.dtype, device=x_data.device)
x_flat = x_data.view(-1)
_BLOCK_SIZE_0 = 16
_BLOCK_SIZE_1 = 8
_BLOCK_SIZE_2 = 16
_launcher(_helion_jagged_sum_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), out.stride(1), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
return out

--- assertExpectedJournal(TestExamples.test_jsd)
from __future__ import annotations

28 changes: 28 additions & 0 deletions test/test_examples.py
@@ -1197,6 +1197,34 @@ def test_int4_gemm(self):
)
)

def test_jagged_sum(self):
num_rows, max_cols = 128, 64
M = 8 # number of features
lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
x_offsets = torch.cat(
[
torch.zeros(1, dtype=torch.long, device=DEVICE),
torch.cumsum(lengths, dim=0),
]
)
nnz = int(x_offsets[-1])
x_data = torch.randn(nnz, M, dtype=torch.float32, device=DEVICE)
args = (x_data, x_offsets)

# Import and use the reference implementation
mod = import_path(EXAMPLES_DIR / "jagged_sum.py")
expected = mod.reference_jagged_sum_kernel_pytorch(x_data, x_offsets)

self.assertExpectedJournal(
check_example(
"jagged_sum",
args,
expected,
fn_name="jagged_sum_kernel",
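                # block sizes for the kernel's [tile_b, tile_m, tile_k] loops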
block_sizes=[16, 8, 16],
)
)

def test_fused_linear_jsd(self):
beta = 0.5
ignore_index = 1