Skip to content

Commit 48cfdec

Browse files
committed
Add jagged_mean example
stack-info: PR: #263, branch: yf225/stack/9
1 parent 110df59 commit 48cfdec

File tree

3 files changed

+465
-1
lines changed

3 files changed

+465
-1
lines changed

examples/jagged_mean.py

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
from __future__ import annotations
2+
3+
import torch
4+
5+
import helion
6+
from helion._testing import run_example
7+
import helion.language as hl
8+
9+
10+
@helion.kernel()
def jagged_mean_kernel(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
    x_feature_counts: torch.Tensor,  # [num_rows] - number of features per row
    max_M_tensor: torch.Tensor,  # Dummy tensor whose size indicates max features
) -> torch.Tensor:
    """
    Compute the per-feature mean of each row of a jagged tensor, where every
    row may also have a different number of valid features.

    Args
    ----
    x_data          : 2-D tensor of shape (total_elements, max_M) holding all elements.
                      It is flattened with ``view(-1)`` and addressed as
                      ``element * max_M + feature``.
    x_offsets       : (num_rows + 1) tensor. Row i is the slice
                      x_data[x_offsets[i] : x_offsets[i+1], :].
    x_feature_counts: (num_rows) tensor. Number of valid features for each row.
    max_M_tensor    : Dummy tensor whose numel() gives max number of features.

    Returns
    -------
    result : 2-D tensor of shape (num_rows, max_M), same dtype and device as
             x_data, containing the mean of each row.
             Invalid features (beyond x_feature_counts[i]) are set to 0, and
             rows with no elements (x_offsets[i] == x_offsets[i+1]) are all 0.
    """
    num_rows = x_offsets.size(0) - 1
    max_M = max_M_tensor.numel()  # Extract max features from dummy tensor

    out = torch.zeros([num_rows, max_M], dtype=x_data.dtype, device=x_data.device)

    # Flatten x_data so an (element, feature) pair maps to one linear index.
    # NOTE(review): the index math below assumes x_data has exactly max_M
    # columns, as the docstring states - confirm at call sites.
    x_flat = x_data.view(-1)

    # Process rows in tiles
    for tile_b in hl.tile(num_rows):
        # Start/end element offsets and element count for each row in the tile.
        starts = x_offsets[tile_b]
        ends = x_offsets[tile_b.index + 1]
        nnz = ends - starts
        # The longest row in this tile bounds the inner element loop; shorter
        # rows are handled by row_mask below.
        max_nnz = nnz.amax()

        # Get feature counts for this tile of rows
        feature_counts = x_feature_counts[tile_b]

        # Process features in tiles
        for tile_m in hl.tile(max_M):
            # [rows, features] mask: True where the feature index is below the
            # row's feature count.
            feature_valid = tile_m.index < feature_counts[:, None]

            # Running per-(row, feature) sum, kept in x_data's dtype.
            row_sums = hl.zeros([tile_b, tile_m], dtype=x_data.dtype)

            # Process elements within each row
            for tile_k in hl.tile(0, max_nnz):
                # Linear index of (starts + k, m) in x_flat:
                # [rows, k] element index * max_M + [features] feature index
                # -> [rows, k, features].
                base_indices = starts[:, None] + tile_k.index[None, :]
                flat_indices = (
                    base_indices[:, :, None] * max_M + tile_m.index[None, None, :]
                )

                # Combined mask: valid row element AND valid feature
                row_mask = tile_k.index[None, :] < nnz[:, None]
                combined_mask = row_mask[:, :, None] & feature_valid[:, None, :]

                # NOTE(review): the sum below relies on masked-out positions
                # contributing zero - confirm hl.load's extra_mask semantics.
                x_slice = hl.load(
                    x_flat,
                    [flat_indices],
                    extra_mask=combined_mask,
                )
                # Accumulate - sum across the k dimension (dim=1)
                row_sums = row_sums + x_slice.sum(dim=1)

            # Element count per row, cast to x_data's dtype and shaped
            # [rows, 1] to broadcast over features.
            nnz_float = nnz.to(x_data.dtype)
            nnz_expanded = nnz_float[:, None]

            # Mean = sum / count; rows with no elements get 0 instead of the
            # result of dividing by zero.
            result = torch.where(nnz_expanded > 0, row_sums / nnz_expanded, 0.0)

            # Apply feature mask to output
            out[tile_b, tile_m] = torch.where(feature_valid, result, 0.0)

    return out
90+
91+
92+
def reference_jagged_mean_kernel_pytorch(
    x_data: torch.Tensor,
    x_offsets: torch.Tensor,
    x_feature_counts: torch.Tensor,
    max_M: int,
) -> torch.Tensor:
    """PyTorch reference implementation for jagged mean with variable features.

    Args:
        x_data: 2-D tensor; row ``i`` of the jagged input is
            ``x_data[x_offsets[i] : x_offsets[i + 1]]``.
        x_offsets: Tensor with ``num_rows + 1`` entries delimiting the rows.
        x_feature_counts: Tensor with one entry per row giving the number of
            leading feature columns to average for that row.
        max_M: Number of columns in the output.

    Returns:
        Tensor of shape ``(num_rows, max_M)`` with the dtype and device of
        ``x_data``. ``out[i, :count]`` holds the column-wise mean of row
        ``i``; every other entry, and every row that is empty or has a
        non-positive feature count, is 0.
    """
    num_rows = x_offsets.numel() - 1
    out = torch.zeros((num_rows, max_M), dtype=x_data.dtype, device=x_data.device)

    # Pull the index tensors to the host once. Reading them element by element
    # with int(tensor[i]) inside the loop costs one device->host transfer per
    # read (three per row) when the inputs live on an accelerator.
    bounds = [int(v) for v in x_offsets.reshape(-1).tolist()]
    counts = [int(v) for v in x_feature_counts.reshape(-1).tolist()]

    for i in range(num_rows):
        start, end = bounds[i], bounds[i + 1]
        num_features = counts[i]
        # Skip empty rows and rows without valid features; they stay zero.
        if end > start and num_features > 0:
            out[i, :num_features] = x_data[start:end, :num_features].mean(dim=0)
    return out
108+
109+
110+
def jagged_mean_tritonbench(
    x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
) -> torch.Tensor:
    """
    Adapt the tritonbench call signature to ``jagged_mean_kernel``.

    Every row is given the full feature count ``M``, so no feature is masked.

    Args:
        x: Nested tensor in jagged format with shape (B, *, M)
        B: Batch size
        M: Number of features
        seqlen: Maximum sequence length (not used)
        sparsity: Sparsity factor (not used)

    Returns:
        Tensor of shape (B, M) with mean values per row and feature
    """
    # Unpack the jagged tensor into its flat values and row offsets.
    values = x._values
    offsets = x._offsets  # pyright: ignore[reportAttributeAccessIssue]
    target_device = values.device  # pyright: ignore[reportAttributeAccessIssue]

    # All B rows use all M features.
    counts_per_row = torch.full((B,), M, dtype=torch.int32, device=target_device)
    # Only the number of elements of this tensor is read by the kernel.
    feature_extent = torch.empty(M, device=target_device)

    return jagged_mean_kernel(values, offsets, counts_per_row, feature_extent)
138+
139+
140+
def main() -> None:
    """Build a random jagged input on CUDA and compare kernel vs. reference."""
    device = "cuda"
    num_rows, max_cols = 32, 64
    M = 8  # number of features

    # Random row lengths in [1, max_cols], turned into cumulative offsets
    # with a leading zero.
    lengths = torch.randint(1, max_cols + 1, (num_rows,), device=device)
    leading_zero = torch.zeros(1, dtype=torch.long, device=device)
    x_offsets = torch.cat([leading_zero, torch.cumsum(lengths, dim=0)])

    # The last offset is the total number of stored elements.
    total_elements = int(x_offsets[-1])
    x_data = torch.randn(total_elements, M, dtype=torch.float32, device=device)

    # Each row gets between 1 and M valid features.
    feature_counts = torch.randint(
        1, M + 1, (num_rows,), dtype=torch.int32, device=device
    )
    max_M_tensor = torch.empty(M, device=device)

    def run_kernel(x, o, fc, mt):
        return jagged_mean_kernel(x, o, fc, mt)

    def run_reference(x, o, fc, mt):
        # The reference takes the feature count as an int, not a dummy tensor.
        return reference_jagged_mean_kernel_pytorch(x, o, fc, mt.numel())

    run_example(
        run_kernel,
        run_reference,
        (x_data, x_offsets, feature_counts, max_M_tensor),
    )
161+
162+
163+
# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)