Skip to content

Commit de4aa39

Browse files
authored
[Benchmark] jagged_sum kernel and test (#676)
1 parent 32c2156 commit de4aa39

File tree

4 files changed

+319
-0
lines changed

4 files changed

+319
-0
lines changed

benchmarks/run.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,11 @@ class RunResult:
260260
"num_inputs": 10, # int4_gemm takes long time on Benchmark CI, so use fewer inputs instead.
261261
},
262262
),
263+
"jagged_sum": (
264+
"tritonbench.operators.jagged_sum.operator",
265+
"examples.jagged_sum",
266+
"jagged_sum_tritonbench",
267+
),
263268
}
264269

265270

@@ -417,6 +422,14 @@ class RunResult:
417422
"helion_grouped_gemm_jagged_persistent_tritonbench-speedup": "helion_speedup",
418423
"helion_grouped_gemm_jagged_persistent_tritonbench-accuracy": "helion_accuracy",
419424
},
425+
"jagged_sum": {
426+
"triton_jagged_sum_no_pad_simple_fused-speedup": "triton_speedup",
427+
"triton_jagged_sum_no_pad_simple_fused-accuracy": "triton_accuracy",
428+
"torch_compile_nested_tensor_integration-speedup": "torch_compile_speedup",
429+
"torch_compile_nested_tensor_integration-accuracy": "torch_compile_accuracy",
430+
"helion_jagged_sum_tritonbench-speedup": "helion_speedup",
431+
"helion_jagged_sum_tritonbench-accuracy": "helion_accuracy",
432+
},
420433
"addmm": {
421434
"aten_addmm": "baseline",
422435
"triton_addmm-speedup": "triton_speedup",

examples/jagged_sum.py

Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
"""
2+
Jagged Sum Example
3+
==================
4+
5+
This example demonstrates how to compute the sum over the jagged dimension of each row
6+
in a jagged tensor using Helion.
7+
"""
8+
9+
# %%
10+
# Imports
11+
# -------
12+
from __future__ import annotations
13+
14+
from typing import Callable
15+
16+
import torch
17+
18+
import helion
19+
from helion._testing import run_example
20+
import helion.language as hl
21+
22+
23+
# %%
24+
# Jagged Sum Kernel
25+
# ---------------
26+
@helion.kernel()
def jagged_sum_kernel(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    """
    Compute, for each row of a jagged tensor, the sum over its jagged dimension.

    Args:
        x_data: 2-D tensor of shape (total_elements, M) holding all elements
        x_offsets: (num_rows + 1) tensor. Row i is the slice
                   x_data[x_offsets[i] : x_offsets[i+1], :]

    Returns:
        2-D tensor of shape (num_rows, M) where entry [i, m] is the sum of
        x_data[x_offsets[i] : x_offsets[i+1], m]. Rows with no elements stay zero.
    """
    M = x_data.shape[1]
    num_rows = x_offsets.size(0) - 1

    out = torch.zeros([num_rows, M], dtype=x_data.dtype, device=x_data.device)

    # Flatten x_data so element (r, m) can be addressed as r * M + m
    x_flat = x_data.view(-1)

    # Process rows in tiles
    for tile_b in hl.tile(num_rows):
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        nnz = ends - starts
        # Longest row in this tile bounds the inner loop over row elements
        max_nnz = nnz.amax()

        # Process features in tiles
        for tile_m in hl.tile(M):
            # Accumulator uses the input dtype, matching the dtype of `out`
            row_sums = hl.zeros([tile_b, tile_m], dtype=x_data.dtype)

            # Process elements within each row
            for tile_k in hl.tile(0, max_nnz):
                # Flattened index of (starts + k, m) for every row/k/feature triple
                base_indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = (
                    base_indices[:, :, None] * M + tile_m.index[None, None, :]
                )

                # Rows shorter than max_nnz must not read past their own end.
                # Only this row-length mask is built explicitly; it is broadcast
                # across the feature axis.
                row_mask = tile_k.index[None, :] < nnz[:, None]
                combined_mask = row_mask[:, :, None]

                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )
                # Accumulate - sum across the k dimension (dim=1)
                row_sums = row_sums + x_slice.sum(dim=1)

            # Write the finished sums for this (row tile, feature tile)
            out[tile_b, tile_m] = row_sums

    return out
86+
87+
88+
# %%
89+
# Reference Implementation
90+
# --------------------
91+
def reference_jagged_sum_kernel_pytorch(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
) -> torch.Tensor:
    """
    PyTorch reference implementation for the jagged sum.

    Args:
        x_data: 2-D tensor of shape (total_elements, M) holding all elements
        x_offsets: (num_rows + 1) offsets tensor; row i is
                   x_data[x_offsets[i] : x_offsets[i+1], :]

    Returns:
        Tensor of shape (num_rows, M) containing the sum of each row's
        elements; rows with no elements are left as zeros.
    """
    num_rows = x_offsets.numel() - 1
    M = x_data.size(1)
    out = torch.zeros((num_rows, M), dtype=x_data.dtype, device=x_data.device)

    # Copy the offsets to the host once instead of converting two tensor
    # elements to Python ints on every loop iteration.
    bounds = x_offsets.tolist()
    for i, (start, end) in enumerate(zip(bounds, bounds[1:])):
        if end > start:
            out[i, :] = x_data[start:end, :].sum(dim=0)
    return out
114+
115+
116+
# %%
117+
# Benchmark Wrapper
118+
# --------------
119+
def jagged_sum_tritonbench(
    tb_op: object, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
) -> Callable[[], torch.Tensor]:
    """
    Wrapper for tritonbench that matches the expected interface.

    Args:
        tb_op: TritonBench operator instance (not used)
        x: Nested tensor in jagged format with shape (B, *, M)
        B: Batch size (not used)
        M: Number of features (not used)
        seqlen: Maximum sequence length (not used)
        sparsity: Sparsity factor (not used)

    Returns:
        Callable that returns tensor of shape (B, M) with the sum over the
        jagged dimension per row and feature
    """
    # Unpack the nested tensor once, outside the benchmarked callable
    x_values = x._values
    x_offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]

    return lambda: jagged_sum_kernel(x_values, x_offsets)
140+
141+
142+
# %%
143+
# Helper function to create test data
144+
# ---------------------------------
145+
def create_test_jagged_tensor(
    B: int,
    M: int,
    max_seqlen: int,
    device: str = "cuda",
    dtype: torch.dtype = torch.float32,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Build random jagged test data.

    Args:
        B: Number of jagged rows
        M: Number of features per element
        max_seqlen: Inclusive upper bound on each row's length (lower bound is 1)
        device: Device on which the tensors are created
        dtype: Dtype of the values tensor

    Returns:
        Tuple (x_data, x_offsets): x_data has shape (total_elements, M) and
        x_offsets has B + 1 entries starting at 0.
    """
    # Each row gets between 1 and max_seqlen elements
    row_lengths = torch.randint(1, max_seqlen + 1, (B,), device=device)

    # Offsets are the running total of the lengths, prefixed with 0
    leading_zero = torch.zeros(1, dtype=torch.long, device=device)
    x_offsets = torch.cat([leading_zero, row_lengths.cumsum(dim=0)])

    # The last offset is the total number of elements across all rows
    total_elements = int(x_offsets[-1])
    x_data = torch.randn(total_elements, M, dtype=dtype, device=device)

    return x_data, x_offsets
170+
171+
172+
# %%
173+
# Main Function
174+
# -----------
175+
def main() -> None:
    """
    Main entry point that runs the jagged sum kernel verification.

    Creates a random jagged tensor on CUDA, then compares the Helion kernel
    against the PyTorch reference implementation.
    """
    B, M, max_seqlen = 8, 128, 64
    device = "cuda"

    x_data, x_offsets = create_test_jagged_tensor(
        B, M, max_seqlen, device, dtype=torch.float32
    )

    # Both callables already take (x_data, x_offsets), so no wrapper lambdas are needed
    run_example(
        jagged_sum_kernel,
        reference_jagged_sum_kernel_pytorch,
        (x_data, x_offsets),
    )
194+
195+
196+
if __name__ == "__main__":
197+
main()

test/test_examples.expected

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1758,6 +1758,87 @@ def jagged_softmax_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _lau
17581758
_launcher(_helion_jagged_softmax_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, _BLOCK_SIZE_3, num_warps=4, num_stages=3)
17591759
return out.reshape(N, M)
17601760

1761+
--- assertExpectedJournal(TestExamples.test_jagged_sum)
1762+
from __future__ import annotations
1763+
1764+
import torch
1765+
import triton
1766+
import triton.language as tl
1767+
from helion.runtime import default_launcher as _default_launcher
1768+
1769+
@triton.jit
1770+
def _helion_jagged_sum_kernel(x_offsets, x_flat, out, out_stride_0, out_stride_1, x_flat_stride_0, x_offsets_stride_0, num_rows, M, _BLOCK_SIZE_0: tl.constexpr, _BLOCK_SIZE_1: tl.constexpr, _BLOCK_SIZE_2: tl.constexpr):
1771+
pid_0 = tl.program_id(0)
1772+
offset_0 = pid_0 * _BLOCK_SIZE_0
1773+
indices_0 = (offset_0 + tl.arange(0, _BLOCK_SIZE_0)).to(tl.int32)
1774+
mask_0 = indices_0 < num_rows
1775+
starts = tl.load(x_offsets + indices_0 * x_offsets_stride_0, mask_0, other=0)
1776+
v_0 = tl.full([], 1, tl.int32)
1777+
v_1 = indices_0 + v_0
1778+
ends = tl.load(x_offsets + v_1 * x_offsets_stride_0, mask_0, other=0)
1779+
v_2 = ends - starts
1780+
_mask_to = tl.where(mask_0, v_2, tl.full([], -9223372036854775808, tl.int64))
1781+
max_nnz = tl.cast(tl.max(_mask_to, 0), tl.int64)
1782+
for offset_1 in tl.range(0, M.to(tl.int32), _BLOCK_SIZE_1):
1783+
indices_1 = offset_1 + tl.arange(0, _BLOCK_SIZE_1).to(tl.int32)
1784+
mask_1 = indices_1 < M
1785+
max_nnz_copy = max_nnz
1786+
starts_copy = starts
1787+
v_2_copy = v_2
1788+
max_nnz_copy_0 = max_nnz_copy
1789+
starts_copy_0 = starts_copy
1790+
v_2_copy_0 = v_2_copy
1791+
row_sums = tl.full([_BLOCK_SIZE_0, _BLOCK_SIZE_1], 0.0, tl.float32)
1792+
for offset_2 in tl.range(0, max_nnz_copy_0.to(tl.int32), _BLOCK_SIZE_2):
1793+
indices_2 = offset_2 + tl.arange(0, _BLOCK_SIZE_2).to(tl.int32)
1794+
mask_2 = indices_2 < max_nnz_copy_0
1795+
starts_copy_0_copy = starts_copy_0
1796+
v_2_copy_0_copy = v_2_copy_0
1797+
row_sums_copy = row_sums
1798+
starts_copy_0_copy_0 = starts_copy_0_copy
1799+
v_2_copy_0_copy_0 = v_2_copy_0_copy
1800+
row_sums_copy_0 = row_sums_copy
1801+
subscript = starts_copy_0_copy_0[:, None]
1802+
subscript_1 = indices_2[None, :]
1803+
v_3 = tl.cast(subscript_1, tl.int64)
1804+
v_4 = subscript + v_3
1805+
subscript_2 = v_4[:, :, None]
1806+
v_5 = subscript_2 * M
1807+
subscript_3 = indices_1[None, None, :]
1808+
v_6 = tl.cast(subscript_3, tl.int64)
1809+
v_7 = v_5 + v_6
1810+
subscript_4 = indices_2[None, :]
1811+
subscript_5 = v_2_copy_0_copy_0[:, None]
1812+
v_8 = tl.cast(subscript_4, tl.int64)
1813+
v_9 = v_8 < subscript_5
1814+
combined_mask = v_9[:, :, None]
1815+
x_slice = tl.load(x_flat + v_7 * x_flat_stride_0, mask_0[:, None, None] & mask_2[None, :, None] & mask_1[None, None, :] & combined_mask, other=0)
1816+
sum_1 = tl.cast(tl.sum(x_slice, 1), tl.float32)
1817+
row_sums = row_sums_copy_0 + sum_1
1818+
tl.store(out + (indices_0[:, None] * out_stride_0 + indices_1[None, :] * out_stride_1), row_sums, mask_0[:, None] & mask_1[None, :])
1819+
1820+
def jagged_sum_kernel(x_data: torch.Tensor, x_offsets: torch.Tensor, *, _launcher=_default_launcher):
1821+
"""
1822+
Compute the mean of each row in a jagged tensor with variable features per row.
1823+
1824+
Args:
1825+
x_data: 2-D tensor of shape (total_elements, M) holding all elements
1826+
x_offsets: (num_rows + 1) tensor. Row i is the slice
1827+
x_data[x_offsets[i] : x_offsets[i+1], :]
1828+
1829+
Returns:
1830+
2-D tensor of shape (num_rows, M) containing the sum of jagged dimension.
1831+
"""
1832+
M = x_data.shape[1]
1833+
num_rows = x_offsets.size(0) - 1
1834+
out = torch.zeros([num_rows, M], dtype=x_data.dtype, device=x_data.device)
1835+
x_flat = x_data.view(-1)
1836+
_BLOCK_SIZE_0 = 16
1837+
_BLOCK_SIZE_1 = 8
1838+
_BLOCK_SIZE_2 = 16
1839+
_launcher(_helion_jagged_sum_kernel, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_flat, out, out.stride(0), out.stride(1), x_flat.stride(0), x_offsets.stride(0), num_rows, M, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
1840+
return out
1841+
17611842
--- assertExpectedJournal(TestExamples.test_jsd)
17621843
from __future__ import annotations
17631844

test/test_examples.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1197,6 +1197,34 @@ def test_int4_gemm(self):
11971197
)
11981198
)
11991199

1200+
def test_jagged_sum(self):
    """Check the jagged_sum example against its PyTorch reference and the expected journal."""
    num_rows, max_cols = 128, 64
    M = 8  # number of features

    # Random row lengths in [1, max_cols], turned into offsets with a leading 0
    row_lengths = torch.randint(1, max_cols + 1, (num_rows,), device=DEVICE)
    leading_zero = torch.zeros(1, dtype=torch.long, device=DEVICE)
    x_offsets = torch.cat([leading_zero, torch.cumsum(row_lengths, dim=0)])

    total_elements = int(x_offsets[-1])
    x_data = torch.randn(total_elements, M, dtype=torch.float32, device=DEVICE)

    # The expected values come from the example's own reference implementation
    example = import_path(EXAMPLES_DIR / "jagged_sum.py")
    expected = example.reference_jagged_sum_kernel_pytorch(x_data, x_offsets)

    journal = check_example(
        "jagged_sum",
        (x_data, x_offsets),
        expected,
        fn_name="jagged_sum_kernel",
        block_sizes=[16, 8, 16],
    )
    self.assertExpectedJournal(journal)
1227+
12001228
def test_fused_linear_jsd(self):
12011229
beta = 0.5
12021230
ignore_index = 1

0 commit comments

Comments
 (0)