
Commit ec6099b

Add jagged hstu attention example (i.e. ragged_attention) (#527)
1 parent 37d037d commit ec6099b

4 files changed, +438 −0 lines changed


benchmarks/run.py

Lines changed: 6 additions & 0 deletions
@@ -61,6 +61,12 @@ class RunResult:
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "ragged_attention": (
+        "tritonbench.operators.ragged_attention.operator",
+        "examples.jagged_hstu_attn",
+        "ragged_attention_tritonbench",
+        {"target_size": 0},
+    ),
     "embedding": (
         "tritonbench.operators.embedding.operator",
         "examples.embedding",

examples/jagged_hstu_attn.py

Lines changed: 266 additions & 0 deletions
@@ -0,0 +1,266 @@
"""
Simplified Jagged HSTU Attention Forward Example
================================================

This example demonstrates a simplified version of jagged HSTU attention using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl

try:
    from generative_recommenders.ops.triton.triton_hstu_attention import (  # pyright: ignore[reportMissingImports]
        triton_hstu_mha,
    )

    HAS_HAMMER = True
except ImportError:
    HAS_HAMMER = False


# %%
# Reference Implementation
# ------------------------
def reference_jagged_hstu_kernel_pytorch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_seq_len: int,
) -> torch.Tensor:
    """Simple PyTorch implementation of HSTU jagged attention"""
    # Initialize output
    output = torch.zeros_like(v)

    # Scale factor
    scale = 1.0 / max_seq_len
    alpha = 1.0 / v.size(2) ** 2

    # Compute per-batch sequence lengths
    seq_lens = seq_offsets[1:] - seq_offsets[:-1]

    q_split = torch.split(q, seq_lens.tolist(), dim=0)
    k_split = torch.split(k, seq_lens.tolist(), dim=0)
    v_split = torch.split(v, seq_lens.tolist(), dim=0)

    # Get the batches
    for i, (q_batch, k_batch, v_batch) in enumerate(
        zip(q_split, k_split, v_split, strict=False)
    ):
        q_batch = q_batch.transpose(0, 1)  # [heads, seq_len, head_dim]
        k_batch = k_batch.permute(1, 2, 0)  # [heads, head_dim, seq_len]
        v_batch = v_batch.transpose(0, 1)  # [heads, seq_len, head_dim]

        # Compute attention scores using batch matrix multiplication
        scores = torch.bmm(q_batch, k_batch) * alpha

        # Apply SiLU activation
        scores = (scores / (1.0 + torch.exp(-scores))) * scale

        # Apply lower triangular mask (causal attention)
        invalid_mask = torch.tril(torch.ones_like(scores, dtype=torch.bool), diagonal=0)
        scores = torch.where(invalid_mask, scores, torch.zeros_like(scores))

        # Compute and store output
        output_batch = torch.bmm(scores, v_batch)
        output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1)

    return output


# %%
# Jagged HSTU Attention Kernel
# ----------------------------
@helion.kernel()
def _helion_jagged_attention_kernel(
    max_seq_len: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
) -> torch.Tensor:
    """Helion implementation of HSTU jagged attention"""
    scale = 1.0 / max_seq_len
    num_heads = hl.specialize(q.size(1))
    num_batches = hl.specialize(seq_offsets.size(0) - 1)
    dimV = hl.specialize(v.size(2))

    out = torch.zeros_like(v)

    # Tile over batch, head, sequence
    for tile_b, tile_h, tile_q in hl.tile(
        [num_batches, num_heads, max_seq_len], block_size=[1, 1, None]
    ):
        starts = seq_offsets[tile_b.begin]
        ends = seq_offsets[tile_b.begin + 1]
        seq_len = ends - starts

        if tile_q.begin < seq_len:
            mask_q = tile_q.index < seq_len
            q_blk = q[tile_q.index + starts, tile_h.begin, :]
            acc = hl.zeros([tile_q, dimV], dtype=torch.float32)

            # Causal attention: only attend to previous tokens
            for tile_kv in hl.tile(0, tile_q.end, block_size=None):
                mask_kv = tile_kv.index < seq_len
                k_blk = k[tile_kv.index + starts, tile_h.begin, :]
                v_blk = v[tile_kv.index + starts, tile_h.begin, :]

                # Compute attention scores with SiLU activation
                scores = (
                    torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
                    * scale
                )

                # Apply causal mask: only attend to previous positions
                scores = torch.where(
                    (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0))
                    & mask_q[:, None]
                    & mask_kv[None, :],
                    scores,
                    0.0,
                )

                acc += torch.matmul(scores.to(v.dtype), v_blk)

            # Store result
            out[tile_q.index + starts, tile_h.begin, :] = acc.to(out.dtype)

    return out


# %%
# Benchmark Wrapper
# -----------------
def ragged_attention_tritonbench(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_seq_len: int,
) -> torch.Tensor:
    """Wrapper function for jagged attention kernel"""
    return _helion_jagged_attention_kernel(
        max_seq_len=max_seq_len,
        alpha=1.0 / v.size(2) ** 2,
        q=q,
        k=k,
        v=v,
        seq_offsets=seq_offsets,
    )


# %%
# Testing Function
# ----------------
def test(
    batch_size: int,
    max_seq_len: int,
    heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | str = "cuda",
) -> None:
    """
    Test the jagged HSTU attention kernel implementation.

    Args:
        batch_size: Number of sequences in the batch
        max_seq_len: Maximum sequence length
        heads: Number of attention heads
        head_dim: Dimension of each attention head
        dtype: Data type for the tensors
        device: Device to run the test on
    """
    device = torch.device(device)

    # Generate random sequence lengths
    min_seq_len = max_seq_len // 2
    seq_lengths = torch.randint(
        min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device
    )
    seq_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )
    total_seq_len = int(seq_offsets[-1].item())

    # q, k, v: [total_seq_len, heads, head_dim]
    q = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    k = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    v = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )

    baselines = {
        "torch": reference_jagged_hstu_kernel_pytorch,
    }
    if HAS_HAMMER:

        def _triton_hstu_mha(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_offsets: torch.Tensor,
            num_targets: torch.Tensor | None,
            max_seq_len: int,
        ) -> torch.Tensor:
            return triton_hstu_mha(  # pyright: ignore[reportPossiblyUnboundVariable,reportCallIssue]
                max_seq_len,
                alpha=1.0 / v.size(2) ** 2,
                q=q,
                k=k,
                v=v,
                seq_offsets=seq_offsets,
                num_targets=num_targets,
                max_attn_len=0,
                contextual_seq_len=0,
            )

        baselines["tritonbench"] = _triton_hstu_mha

    run_example(
        ragged_attention_tritonbench,
        baselines,
        (q, k, v, seq_offsets, None, max_seq_len),
    )


# %%
# Main Function
# -------------
def main() -> None:
    """
    Main entry point for testing the simplified jagged HSTU attention kernel.
    """
    test(batch_size=1024, max_seq_len=1024, heads=4, head_dim=128, dtype=torch.bfloat16)


if __name__ == "__main__":
    main()
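For readers new to the jagged layout, here is a small CPU-only sketch (shapes and values chosen purely for illustration, assuming the repository root is on sys.path) that packs two variable-length sequences along dim 0 and runs the PyTorch reference above:

# Illustrative only: two sequences of lengths 3 and 5 packed along dim 0,
# so seq_offsets = [0, 3, 8] and q/k/v have shape [total_seq_len, heads, head_dim].
import torch

from examples.jagged_hstu_attn import reference_jagged_hstu_kernel_pytorch

heads, head_dim, max_seq_len = 2, 4, 8
seq_offsets = torch.tensor([0, 3, 8], dtype=torch.int32)
total_seq_len = int(seq_offsets[-1])  # 8 rows in the packed (jagged) batch

q = torch.randn(total_seq_len, heads, head_dim)
k = torch.randn(total_seq_len, heads, head_dim)
v = torch.randn(total_seq_len, heads, head_dim)

out = reference_jagged_hstu_kernel_pytorch(q, k, v, seq_offsets, None, max_seq_len)
print(out.shape)  # torch.Size([8, 2, 4]); rows 0-2 belong to the first sequence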

test/test_examples.expected

Lines changed: 104 additions & 0 deletions
@@ -910,6 +910,110 @@ def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.
    _launcher(_helion_jagged_dense_add_2d, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_data, y, out, y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), num_rows, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_jagged_hstu_attn)
from __future__ import annotations

import torch
import triton
import triton.language as tl
from helion.runtime import default_launcher as _default_launcher

@triton.jit
def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, k_stride_0, k_stride_1, k_stride_2, out_stride_0, out_stride_1, out_stride_2, q_stride_0, q_stride_1, q_stride_2, seq_offsets_stride_0, v_stride_0, v_stride_1, v_stride_2, max_seq_len, alpha, scale, _BLOCK_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
    num_blocks_0 = 4
    num_blocks_1 = 8
    pid_0 = tl.program_id(0) % num_blocks_0
    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
    offset_0 = pid_0
    offset_1 = pid_1
    offset_2 = pid_2 * _BLOCK_SIZE_2
    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
    mask_2 = indices_2 < max_seq_len
    indices_5 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
    starts = tl.load(seq_offsets + offset_0 * seq_offsets_stride_0, None)
    add = 1 + offset_0
    ends = tl.load(seq_offsets + add * seq_offsets_stride_0, None)
    v_0 = ends - starts
    v_1 = v_0 > offset_2
    if v_1:
        v_0_copy = v_0
        starts_copy = starts
        v_0_copy_0 = v_0_copy
        starts_copy_0 = starts_copy
        v_2 = v_0_copy_0[None]
        v_3 = v_2.to(tl.int32)
        v_4 = indices_2 < v_3
        v_5 = starts_copy_0[None]
        v_6 = v_5.to(tl.int32)
        v_7 = indices_2 + v_6
        q_blk = tl.load(q + (v_7[:, None] * q_stride_0 + offset_1 * q_stride_1 + indices_5[None, :] * q_stride_2), mask_2[:, None], other=0)
        acc = tl.full([_BLOCK_SIZE_2, 32], 0.0, tl.float32)
        tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, max_seq_len)
        for offset_3 in tl.range(0, tile_end.to(tl.int32), _BLOCK_SIZE_4):
            indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
            mask_4 = indices_3 < tile_end
            v_0_copy_0_copy = v_0_copy_0
            starts_copy_0_copy = starts_copy_0
            q_blk_copy = q_blk
            v_4_copy = v_4
            acc_copy = acc
            v_0_copy_0_copy_0 = v_0_copy_0_copy
            starts_copy_0_copy_0 = starts_copy_0_copy
            q_blk_copy_0 = q_blk_copy
            v_4_copy_0 = v_4_copy
            acc_copy_0 = acc_copy
            v_8 = v_0_copy_0_copy_0[None]
            v_9 = v_8.to(tl.int32)
            v_10 = indices_3 < v_9
            v_11 = starts_copy_0_copy_0[None]
            v_12 = v_11.to(tl.int32)
            v_13 = indices_3 + v_12
            k_blk = tl.load(k + (v_13[:, None] * k_stride_0 + offset_1 * k_stride_1 + indices_5[None, :] * k_stride_2), mask_4[:, None], other=0)
            v_14 = starts_copy_0_copy_0[None]
            v_15 = v_14.to(tl.int32)
            v_16 = indices_3 + v_15
            v_blk = tl.load(v + (v_16[:, None] * v_stride_0 + offset_1 * v_stride_1 + indices_5[None, :] * v_stride_2), mask_4[:, None], other=0)
            permute = tl.permute(k_blk, [1, 0])
            mm = tl.dot(q_blk_copy_0, permute, input_precision='tf32')
            v_17 = alpha.to(tl.bfloat16)
            v_18 = mm * v_17
            v_19 = v_18.to(tl.float32)
            v_20 = tl.sigmoid(v_19)
            v_21 = v_19 * v_20
            v_22 = v_21.to(tl.bfloat16)
            v_23 = scale.to(tl.bfloat16)
            v_24 = v_22 * v_23
            unsqueeze = indices_2[:, None]
            unsqueeze_1 = indices_3[None, :]
            v_25 = unsqueeze > unsqueeze_1
            subscript = v_4_copy_0[:, None]
            v_26 = v_25 & subscript
            subscript_1 = v_10[None, :]
            v_27 = v_26 & subscript_1
            v_28 = tl.full([], 0.0, tl.bfloat16)
            v_29 = v_28[None, None]
            v_30 = tl.where(v_27, v_24, v_29)
            _mask_to_2 = tl.where(mask_2[:, None] & mask_4[None, :], v_30, 0)
            mm_1 = tl.dot(_mask_to_2, v_blk, input_precision='tf32')
            v_31 = mm_1.to(tl.float32)
            acc = acc_copy_0 + v_31
        v_33 = acc.to(tl.bfloat16)
        v_34 = starts_copy_0[None]
        v_35 = v_34.to(tl.int32)
        v_36 = indices_2 + v_35
        tl.store(out + (v_36[:, None] * out_stride_0 + offset_1 * out_stride_1 + indices_5[None, :] * out_stride_2), v_33, mask_2[:, None])

def _helion_jagged_attention_kernel(max_seq_len: int, alpha: float, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, *, _launcher=_default_launcher):
    """Helion implementation of HSTU jagged attention"""
    scale = 1.0 / max_seq_len
    out = torch.zeros_like(v)
    _BLOCK_SIZE_2 = 16
    _RDIM_SIZE_3 = 32
    _BLOCK_SIZE_4 = 16
    _launcher(_helion__helion_jagged_attention_kernel, (4 * q.size(1) * triton.cdiv(max_seq_len, _BLOCK_SIZE_2),), seq_offsets, q, k, v, out, k.stride(0), k.stride(1), k.stride(2), out.stride(0), out.stride(1), out.stride(2), q.stride(0), q.stride(1), q.stride(2), seq_offsets.stride(0), v.stride(0), v.stride(1), v.stride(2), max_seq_len, alpha, scale, _BLOCK_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
    return out

--- assertExpectedJournal(TestExamples.test_jagged_mean)
from __future__ import annotations
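A note on the generated launch: the grid flattens the three hl.tile dimensions (batch, head, query tile) into one program id, which the kernel above unpacks with the pid_0/pid_1/pid_2 arithmetic. A plain-Python sketch of that decomposition (block size and the two specialized counts come from the journal; the sequence length is an illustrative assumption):

# Mirrors the pid decomposition in the generated kernel above.
num_blocks_0 = 4   # specialized num_batches (from the example's hl.specialize)
num_blocks_1 = 8   # specialized num_heads (from the example's hl.specialize)
block_size_2 = 16  # _BLOCK_SIZE_2
max_seq_len = 48   # illustrative value, not taken from the journal
num_q_tiles = (max_seq_len + block_size_2 - 1) // block_size_2  # triton.cdiv

for pid in range(num_blocks_0 * num_blocks_1 * num_q_tiles):
    batch = pid % num_blocks_0                     # pid_0
    head = pid // num_blocks_0 % num_blocks_1      # pid_1
    q_tile = pid // (num_blocks_0 * num_blocks_1)  # pid_2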
