Commit 0798aad

jagged hstu example
ghstack-source-id: 0be580b
Pull Request resolved: #527
1 parent 3c0348a commit 0798aad

File tree

1 file changed: +292 -0 lines changed

examples/jagged_hstu_attn.py

@@ -0,0 +1,292 @@
"""
Simplified Jagged HSTU Attention Forward Example
================================================

This example demonstrates a simplified version of jagged HSTU attention using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

from typing import Any

import torch

import helion
from helion._testing import run_example
import helion.language as hl

try:
    from generative_recommenders.ops.triton.triton_hstu_attention import (  # pyright: ignore[reportMissingImports]
        triton_hstu_mha,
    )

    HAS_HAMMER = True
except ImportError:
    HAS_HAMMER = False

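# %%
# Jagged layout
# -------------
# The queries, keys and values for every sequence are packed along the first
# dimension into tensors of shape ``[total_seq_len, heads, head_dim]``, and
# ``seq_offsets`` (length ``batch_size + 1``) records where each sequence
# starts, so sequence ``b`` is the slice ``q[seq_offsets[b] : seq_offsets[b + 1]]``.
# Instead of a softmax, HSTU attention applies a pointwise SiLU to
# ``q @ k.T * alpha`` and scales the result by ``1 / N`` before the
# lower-triangular mask is applied.
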
def generate_inputs() -> dict[str, Any]:
    """Generate small inputs for HSTU attention for easier verification and testing"""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    batch_size = 1024
    max_seq_len = 1024  # N parameter
    heads = 4
    head_dim = 128

    # Generate random sequence lengths
    min_seq_len = max_seq_len // 2
    seq_lengths = torch.randint(
        min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device
    )
    seq_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )
    total_seq_len = int(seq_offsets[-1].item())

    # Generate tensors with ragged sequence length
    # q, k, v: [total_seq_len, heads, head_dim]
    q = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    k = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    v = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )

    # Parameters
    alpha = 1.0 / (head_dim**0.5)  # Scaling factor
    invalid_attn_mask_type = "lower_triangular"

    # Optional parameters (set to None for simplicity)
    num_targets = None
    attn_scale = None
    attn_bias = None
    seq2_offsets = None

    # Integer parameters
    max_attn_len = 0
    contextual_seq_len = 0
    sort_by_length = False
    full_attn_size = 0

    return {
        "N": max_seq_len,
        "alpha": alpha,
        "q": q,
        "k": k,
        "v": v,
        "seq_offsets": seq_offsets,
        "invalid_attn_mask_type": invalid_attn_mask_type,
        "num_targets": num_targets,
        "attn_scale": attn_scale,
        "attn_bias": attn_bias,
        "seq2_offsets": seq2_offsets,
        "max_attn_len": max_attn_len,
        "contextual_seq_len": contextual_seq_len,
        "sort_by_length": sort_by_length,
        "full_attn_size": full_attn_size,
    }

def reference_jagged_hstu_kernel_pytorch(inputs: dict[str, Any]) -> torch.Tensor:
    """Simple PyTorch implementation of HSTU ragged attention using direct tensor slicing"""
    q = inputs["q"]
    k = inputs["k"]
    v = inputs["v"]
    seq_offsets = inputs["seq_offsets"]
    alpha = inputs["alpha"]
    N = inputs["N"]

    # Initialize output
    output = torch.zeros_like(v)

    # Scale factor
    scale = 1.0 / N

    # Compute per-batch sequence lengths
    seq_lens = seq_offsets[1:] - seq_offsets[:-1]

    q_split = torch.split(q, seq_lens.tolist(), dim=0)
    k_split = torch.split(k, seq_lens.tolist(), dim=0)
    v_split = torch.split(v, seq_lens.tolist(), dim=0)

    # Process each sequence in the batch using direct tensor slicing
    for i, (q_batch, k_batch, v_batch) in enumerate(
        zip(q_split, k_split, v_split, strict=False)
    ):
        q_batch = q_batch.transpose(0, 1)  # [heads, seq_len, head_dim]
        k_batch = k_batch.permute(1, 2, 0)  # [heads, head_dim, seq_len]
        v_batch = v_batch.transpose(0, 1)  # [heads, seq_len, head_dim]

        # Compute attention scores using batch matrix multiplication
        scores = torch.bmm(q_batch, k_batch) * alpha

        # Apply SiLU activation
        scores = (scores / (1.0 + torch.exp(-scores))) * scale

        # Apply invalid mask
        if inputs["invalid_attn_mask_type"] == "lower_triangular":
            invalid_mask = torch.tril(
                torch.ones_like(scores, dtype=torch.bool), diagonal=0
            )
            scores = torch.where(invalid_mask, scores, torch.zeros_like(scores))

        # Compute and store output
        output_batch = torch.bmm(scores, v_batch)
        output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1)

    return output

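# %%
# Helion kernel
# -------------
# The kernel below tiles the computation over ``(batch, head, query block)``
# with ``hl.tile`` and accumulates over key/value blocks in an inner
# ``hl.tile`` loop. ``mask_q`` and ``mask_kv`` zero out score entries for
# positions past the end of the current jagged sequence.
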
# @helion.kernel(config=helion.Config(block_sizes=[64, 16], indexing='pointer', l2_groupings=[1], loop_orders=[[2, 0, 1]], num_stages=7, num_warps=4, pid_type='flat', range_flattens=[None, None], range_multi_buffers=[None, None], range_num_stages=[0, 4], range_unroll_factors=[0, 1], range_warp_specializes=[]))
@helion.kernel()
def _helion_ragged_attention_kernel(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    alpha: float,
    invalid_mask_type: str,
    max_seq_len_tensor: torch.Tensor,
) -> torch.Tensor:
    max_seq_len = max_seq_len_tensor.numel()
    scale = 1.0 / max_seq_len

    num_heads = hl.specialize(q.size(1))
    num_batches = hl.specialize(seq_offsets.size(0) - 1)
    dimV = hl.specialize(v.size(2))

    out = torch.zeros_like(v)

    # --- Tile over batch, head, sequence ---
    for tile_b, tile_h, tile_q in hl.tile(
        [num_batches, num_heads, max_seq_len], block_size=[1, 1, None]
    ):
        starts = seq_offsets[tile_b.begin]
        ends = seq_offsets[tile_b.begin + 1]
        seq_len = ends - starts

        if tile_q.begin < seq_len:
            mask_q = tile_q.index < seq_len
            q_blk = q[tile_q.index + starts, tile_h.begin, :]
            acc = hl.zeros([tile_q, dimV], dtype=torch.float32)

            # For the lower-triangular mask, keys at or beyond the end of the
            # current query tile cannot contribute, so the inner loop stops at
            # tile_q.end.
            if invalid_mask_type == "lower_triangular":
                low = 0
                high = tile_q.end
            else:
                low = 0
                high = seq_len

            for tile_kv in hl.tile(low, high, block_size=None):
                mask_kv = tile_kv.index < seq_len
                k_blk = k[tile_kv.index + starts, tile_h.begin, :]
                v_blk = v[tile_kv.index + starts, tile_h.begin, :]

                # Pointwise SiLU attention scaled by 1/max_seq_len (no softmax)
                scores = (
                    torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
                    * scale
                )

                if invalid_mask_type == "lower_triangular":
                    scores = torch.where(
                        (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0))
                        & mask_q[:, None]
                        & mask_kv[None, :],
                        scores,
                        0.0,
                    )

                acc += torch.matmul(scores.to(v.dtype), v_blk)

            # Store result
            out[tile_q.index + starts, tile_h.begin, :] = acc

    return out

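# %%
# ``N`` (the maximum sequence length) is not passed to the kernel as a plain
# integer: the wrapper below passes ``torch.empty(N)`` and the kernel recovers
# the value via ``max_seq_len_tensor.numel()``.
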
def helion_ragged_attention_function(inputs: dict[str, Any]) -> torch.Tensor:
    """
    Wrapper function for the Helion ragged attention kernel.
    """

    return _helion_ragged_attention_kernel(
        q=inputs["q"],
        k=inputs["k"],
        v=inputs["v"],
        seq_offsets=inputs["seq_offsets"],
        alpha=inputs["alpha"],
        invalid_mask_type=inputs["invalid_attn_mask_type"],
        max_seq_len_tensor=torch.empty(inputs["N"], device=inputs["q"].device),
    )

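# %%
# A minimal manual check (a sketch; ``run_example`` in ``main()`` below already
# performs the comparison and benchmarking):
#
#     inputs = generate_inputs()
#     out_helion = helion_ragged_attention_function(inputs)
#     out_ref = reference_jagged_hstu_kernel_pytorch(inputs)
#     print((out_helion - out_ref).abs().max())
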
def tritonbench_hstu_attention_function(inputs: dict[str, Any]) -> torch.Tensor:
    """
    Wrapper function for the tritonbench HSTU attention implementation.

    Args:
        inputs: Dictionary containing all the input parameters

    Returns:
        Output tensor from tritonbench HSTU attention
    """
    if not HAS_HAMMER:
        # Return a dummy tensor with the same shape as expected output
        return torch.zeros_like(inputs["v"])

    return triton_hstu_mha(  # pyright: ignore[reportCallIssue,reportPossiblyUnboundVariable]
        N=inputs["N"],
        alpha=inputs["alpha"],
        q=inputs["q"],
        k=inputs["k"],
        v=inputs["v"],
        seq_offsets=inputs["seq_offsets"],
        num_targets=inputs["num_targets"],
        max_attn_len=inputs["max_attn_len"],
        contextual_seq_len=inputs["contextual_seq_len"],
        sort_by_length=inputs["sort_by_length"],
    )

def main() -> None:
    """
    Main entry point for testing the simplified jagged HSTU attention kernel.
    """
    inputs = generate_inputs()

    baselines = {
        "torch": lambda inputs: reference_jagged_hstu_kernel_pytorch(inputs),
    }
    if HAS_HAMMER:
        baselines["tritonbench"] = lambda inputs: tritonbench_hstu_attention_function(
            inputs
        )

    run_example(
        lambda inputs: helion_ragged_attention_function(inputs), baselines, (inputs,)
    )


if __name__ == "__main__":
    main()
