Commit fc73acb

jagged hstu example
ghstack-source-id: e114db1
Pull Request resolved: #527

1 parent: 3c0348a

2 files changed: +257 -0 lines


benchmarks/run.py

Lines changed: 6 additions & 0 deletions
@@ -38,6 +38,12 @@
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "ragged_attention": (
+        "tritonbench.operators.ragged_attention.operator",
+        "examples.jagged_hstu_attn",
+        "ragged_attention_wrapper",
+        {"target_size": 0},
+    ),
     "embedding": (
         "tritonbench.operators.embedding.operator",
         "examples.embedding",

examples/jagged_hstu_attn.py

Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
"""
Simplified Jagged HSTU Attention Forward Example
================================================

This example demonstrates a simplified version of jagged HSTU attention using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl

try:
    from generative_recommenders.ops.triton.triton_hstu_attention import (  # pyright: ignore[reportMissingImports]
        triton_hstu_mha,
    )

    HAS_HAMMER = True
except ImportError:
    HAS_HAMMER = False


def reference_ragged_hstu_kernel_pytorch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_seq_len: int,
) -> torch.Tensor:
    """Simple PyTorch implementation of HSTU ragged attention"""
    # Initialize output
    output = torch.zeros_like(v)

    # Scale factor
    scale = 1.0 / max_seq_len
    alpha = 1.0 / v.size(2) ** 2

    # Compute per-batch sequence lengths
    seq_lens = seq_offsets[1:] - seq_offsets[:-1]

    q_split = torch.split(q, seq_lens.tolist(), dim=0)
    k_split = torch.split(k, seq_lens.tolist(), dim=0)
    v_split = torch.split(v, seq_lens.tolist(), dim=0)

    # Get the batches
    for i, (q_batch, k_batch, v_batch) in enumerate(
        zip(q_split, k_split, v_split, strict=False)
    ):
        q_batch = q_batch.transpose(0, 1)  # [heads, seq_len, head_dim]
        k_batch = k_batch.permute(1, 2, 0)  # [heads, head_dim, seq_len]
        v_batch = v_batch.transpose(0, 1)  # [heads, seq_len, head_dim]

        # Compute attention scores using batch matrix multiplication
        scores = torch.bmm(q_batch, k_batch) * alpha

        # Apply SiLU activation
        scores = (scores / (1.0 + torch.exp(-scores))) * scale

        # Apply lower triangular mask (causal attention)
        invalid_mask = torch.tril(torch.ones_like(scores, dtype=torch.bool), diagonal=0)
        scores = torch.where(invalid_mask, scores, torch.zeros_like(scores))

        # Compute and store output
        output_batch = torch.bmm(scores, v_batch)
        output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1)

    return output
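
The jagged (ragged) layout used above packs every sequence along dim 0 of q, k, and v, and seq_offsets holds the prefix sums of the per-batch lengths, so batch i occupies rows seq_offsets[i]:seq_offsets[i + 1]. A small standalone illustration with made-up lengths (not values from this example):

import torch

# Three sequences of lengths 3, 2, and 4 packed into one tensor along dim 0.
seq_lengths = torch.tensor([3, 2, 4])
seq_offsets = torch.zeros(len(seq_lengths) + 1, dtype=torch.long)
seq_offsets[1:] = torch.cumsum(seq_lengths, dim=0)
print(seq_offsets)  # tensor([0, 3, 5, 9])

heads, head_dim = 2, 4
total_seq_len = int(seq_offsets[-1])
q = torch.randn(total_seq_len, heads, head_dim)  # [total_seq_len, heads, head_dim]
q_batch1 = q[seq_offsets[1] : seq_offsets[2]]  # rows 3:5, i.e. the second sequence
print(q_batch1.shape)  # torch.Size([2, 2, 4])
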
@helion.kernel()
def _helion_ragged_attention_kernel(
    max_seq_len: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
) -> torch.Tensor:
    """Helion implementation of HSTU ragged attention"""
    scale = 1.0 / max_seq_len
    num_heads = hl.specialize(q.size(1))
    num_batches = hl.specialize(seq_offsets.size(0) - 1)
    dimV = hl.specialize(v.size(2))

    out = torch.zeros_like(v)

    # Tile over batch, head, sequence
    for tile_b, tile_h, tile_q in hl.tile(
        [num_batches, num_heads, max_seq_len], block_size=[1, 1, None]
    ):
        starts = seq_offsets[tile_b.begin]
        ends = seq_offsets[tile_b.begin + 1]
        seq_len = ends - starts

        if tile_q.begin < seq_len:
            mask_q = tile_q.index < seq_len
            q_blk = q[tile_q.index + starts, tile_h.begin, :]
            acc = hl.zeros([tile_q, dimV], dtype=torch.float32)

            # Causal attention: only attend to previous tokens
            for tile_kv in hl.tile(0, tile_q.end, block_size=None):
                mask_kv = tile_kv.index < seq_len
                k_blk = k[tile_kv.index + starts, tile_h.begin, :]
                v_blk = v[tile_kv.index + starts, tile_h.begin, :]

                # Compute attention scores with SiLU activation
                scores = (
                    torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
                    * scale
                )

                # Apply causal mask: only attend to previous positions
                scores = torch.where(
                    (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0))
                    & mask_q[:, None]
                    & mask_kv[None, :],
                    scores,
                    0.0,
                )

                acc += torch.matmul(scores.to(v.dtype), v_blk)

            # Store result
            out[tile_q.index + starts, tile_h.begin, :] = acc

    return out
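
In the inner loop above, tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0) broadcasts the per-tile query and key positions into a [tile_q, tile_kv] boolean matrix that keeps a score only where the query position comes strictly after the key position, and mask_q / mask_kv additionally zero out rows and columns that fall past the end of the current sequence. The same broadcast in plain PyTorch, with made-up tile positions:

import torch

# Stand-ins for tile_q.index and tile_kv.index within a single tile.
q_idx = torch.arange(4)
kv_idx = torch.arange(4)
causal = q_idx.unsqueeze(1) > kv_idx.unsqueeze(0)
print(causal)
# tensor([[False, False, False, False],
#         [ True, False, False, False],
#         [ True,  True, False, False],
#         [ True,  True,  True, False]])
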
def ragged_attention_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_seq_len: int,
) -> torch.Tensor:
    """Wrapper function for ragged attention kernel"""
    return _helion_ragged_attention_kernel(
        max_seq_len=max_seq_len,
        alpha=1.0 / v.size(2) ** 2,
        q=q,
        k=k,
        v=v,
        seq_offsets=seq_offsets,
    )


def test(
    batch_size: int,
    max_seq_len: int,
    heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | str = "cuda",
) -> None:
    """
    Test the ragged HSTU attention kernel implementation.

    Args:
        batch_size: Number of sequences in the batch
        max_seq_len: Maximum sequence length
        heads: Number of attention heads
        head_dim: Dimension of each attention head
        dtype: Data type for the tensors
        device: Device to run the test on
    """
    device = torch.device(device)

    # Generate random sequence lengths
    min_seq_len = max_seq_len // 2
    seq_lengths = torch.randint(
        min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device
    )
    seq_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )
    total_seq_len = int(seq_offsets[-1].item())

    # q, k, v: [total_seq_len, heads, head_dim]
    q = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    k = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    v = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )

    baselines = {
        "torch": reference_ragged_hstu_kernel_pytorch,
    }
    if HAS_HAMMER:

        def _triton_hstu_mha(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_offsets: torch.Tensor,
            num_targets: torch.Tensor | None,
            max_seq_len: int,
        ) -> torch.Tensor:
            return triton_hstu_mha(  # pyright: ignore[reportPossiblyUnboundVariable,reportCallIssue]
                max_seq_len,
                alpha=1.0 / v.size(2) ** 2,
                q=q,
                k=k,
                v=v,
                seq_offsets=seq_offsets,
                num_targets=num_targets,
                max_attn_len=0,
                contextual_seq_len=0,
            )

        baselines["tritonbench"] = _triton_hstu_mha

    run_example(
        ragged_attention_wrapper,
        baselines,
        (q, k, v, seq_offsets, None, max_seq_len),
    )


def main() -> None:
    """
    Main entry point for testing the simplified ragged HSTU attention kernel.
    """
    test(batch_size=1024, max_seq_len=1024, heads=4, head_dim=128, dtype=torch.bfloat16)


if __name__ == "__main__":
    main()

0 commit comments
