
Commit e10f585

jagged hstu example

ghstack-source-id: c0da2a9
Pull Request resolved: #527
1 parent c4126b4 commit e10f585

File tree

4 files changed, +423 -0 lines changed


benchmarks/run.py

Lines changed: 6 additions & 0 deletions
@@ -61,6 +61,12 @@ class RunResult:
 KERNEL_MAPPINGS: dict[str, tuple[str, ...]] = {  # pyright: ignore[reportAssignmentType]
     # <tritonbench_op_name>: (<tritonbench_module_path>, <helion_kernel_module_path>, <helion_kernel_function_name>)
     "vector_add": ("tritonbench.operators.vector_add.operator", "examples.add", "add"),
+    "ragged_attention": (
+        "tritonbench.operators.ragged_attention.operator",
+        "examples.jagged_hstu_attn",
+        "jagged_attention_wrapper",
+        {"target_size": 0},
+    ),
     "embedding": (
         "tritonbench.operators.embedding.operator",
         "examples.embedding",

examples/jagged_hstu_attn.py

Lines changed: 251 additions & 0 deletions
@@ -0,0 +1,251 @@
"""
Simplified Jagged HSTU Attention Forward Example
================================================

This example demonstrates a simplified version of jagged HSTU attention using Helion.
"""

# %%
# Imports
# -------
from __future__ import annotations

import torch

import helion
from helion._testing import run_example
import helion.language as hl

try:
    from generative_recommenders.ops.triton.triton_hstu_attention import (  # pyright: ignore[reportMissingImports]
        triton_hstu_mha,
    )

    HAS_HAMMER = True
except ImportError:
    HAS_HAMMER = False


def reference_jagged_hstu_kernel_pytorch(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_seq_len: int,
) -> torch.Tensor:
    """Simple PyTorch implementation of HSTU jagged attention"""
    # Initialize output
    output = torch.zeros_like(v)

    # Scale factor
    scale = 1.0 / max_seq_len
    alpha = 1.0 / v.size(2) ** 2

    # Compute per-batch sequence lengths
    seq_lens = seq_offsets[1:] - seq_offsets[:-1]

    q_split = torch.split(q, seq_lens.tolist(), dim=0)
    k_split = torch.split(k, seq_lens.tolist(), dim=0)
    v_split = torch.split(v, seq_lens.tolist(), dim=0)

    # Get the batches
    for i, (q_batch, k_batch, v_batch) in enumerate(
        zip(q_split, k_split, v_split, strict=False)
    ):
        q_batch = q_batch.transpose(0, 1)  # [heads, seq_len, head_dim]
        k_batch = k_batch.permute(1, 2, 0)  # [heads, head_dim, seq_len]
        v_batch = v_batch.transpose(0, 1)  # [heads, seq_len, head_dim]

        # Compute attention scores using batch matrix multiplication
        scores = torch.bmm(q_batch, k_batch) * alpha

        # Apply SiLU activation
        scores = (scores / (1.0 + torch.exp(-scores))) * scale

        # Apply lower triangular mask (causal attention)
        invalid_mask = torch.tril(torch.ones_like(scores, dtype=torch.bool), diagonal=0)
        scores = torch.where(invalid_mask, scores, torch.zeros_like(scores))

        # Compute and store output
        output_batch = torch.bmm(scores, v_batch)
        output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1)

    return output


@helion.kernel()
def _helion_jagged_attention_kernel(
    max_seq_len: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
) -> torch.Tensor:
    """Helion implementation of HSTU jagged attention"""
    scale = 1.0 / max_seq_len
    num_heads = hl.specialize(q.size(1))
    num_batches = hl.specialize(seq_offsets.size(0) - 1)
    dimV = hl.specialize(v.size(2))

    out = torch.zeros_like(v)

    # Tile over batch, head, sequence
    for tile_b, tile_h, tile_q in hl.tile(
        [num_batches, num_heads, max_seq_len], block_size=[1, 1, None]
    ):
        starts = seq_offsets[tile_b.begin]
        ends = seq_offsets[tile_b.begin + 1]
        seq_len = ends - starts

        if tile_q.begin < seq_len:
            mask_q = tile_q.index < seq_len
            q_blk = q[tile_q.index + starts, tile_h.begin, :]
            acc = hl.zeros([tile_q, dimV], dtype=torch.float32)

            # Causal attention: only attend to previous tokens
            for tile_kv in hl.tile(0, tile_q.end, block_size=None):
                mask_kv = tile_kv.index < seq_len
                k_blk = k[tile_kv.index + starts, tile_h.begin, :]
                v_blk = v[tile_kv.index + starts, tile_h.begin, :]

                # Compute attention scores with SiLU activation
                scores = (
                    torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
                    * scale
                )

                # Apply causal mask: only attend to previous positions
                scores = torch.where(
                    (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0))
                    & mask_q[:, None]
                    & mask_kv[None, :],
                    scores,
                    0.0,
                )

                acc += torch.matmul(scores.to(v.dtype), v_blk)

            # Store result
            out[tile_q.index + starts, tile_h.begin, :] = acc.to(out.dtype)

    return out


def jagged_attention_wrapper(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_seq_len: int,
) -> torch.Tensor:
    """Wrapper function for jagged attention kernel"""
    return _helion_jagged_attention_kernel(
        max_seq_len=max_seq_len,
        alpha=1.0 / v.size(2) ** 2,
        q=q,
        k=k,
        v=v,
        seq_offsets=seq_offsets,
    )


def test(
    batch_size: int,
    max_seq_len: int,
    heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.bfloat16,
    device: torch.device | str = "cuda",
) -> None:
    """
    Test the jagged HSTU attention kernel implementation.

    Args:
        batch_size: Number of sequences in the batch
        max_seq_len: Maximum sequence length
        heads: Number of attention heads
        head_dim: Dimension of each attention head
        dtype: Data type for the tensors
        device: Device to run the test on
    """
    device = torch.device(device)

    # Generate random sequence lengths
    min_seq_len = max_seq_len // 2
    seq_lengths = torch.randint(
        min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device
    )
    seq_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )
    total_seq_len = int(seq_offsets[-1].item())

    # q, k, v: [total_seq_len, heads, head_dim]
    q = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    k = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    v = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )

    baselines = {
        "torch": reference_jagged_hstu_kernel_pytorch,
    }
    if HAS_HAMMER:

        def _triton_hstu_mha(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            seq_offsets: torch.Tensor,
            num_targets: torch.Tensor | None,
            max_seq_len: int,
        ) -> torch.Tensor:
            return triton_hstu_mha(  # pyright: ignore[reportPossiblyUnboundVariable,reportCallIssue]
                max_seq_len,
                alpha=1.0 / v.size(2) ** 2,
                q=q,
                k=k,
                v=v,
                seq_offsets=seq_offsets,
                num_targets=num_targets,
                max_attn_len=0,
                contextual_seq_len=0,
            )

        baselines["tritonbench"] = _triton_hstu_mha

    run_example(
        jagged_attention_wrapper,
        baselines,
        (q, k, v, seq_offsets, None, max_seq_len),
    )


def main() -> None:
    """
    Main entry point for testing the simplified jagged HSTU attention kernel.
    """
    test(batch_size=1024, max_seq_len=1024, heads=4, head_dim=128, dtype=torch.bfloat16)


if __name__ == "__main__":
    main()
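
A minimal standalone usage sketch, not part of this commit, assuming the repository root is importable and a CUDA device is available; it mirrors the packed [total_seq_len, heads, head_dim] layout and the seq_offsets construction used in test() above:

    # Usage sketch (assumptions: CUDA available, repo root on PYTHONPATH).
    import torch

    from examples.jagged_hstu_attn import jagged_attention_wrapper

    device = torch.device("cuda")
    seq_lengths = torch.tensor([3, 5], dtype=torch.int32, device=device)  # two jagged sequences
    seq_offsets = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )  # -> [0, 3, 8]
    total_seq_len, heads, head_dim = int(seq_offsets[-1]), 4, 32
    q = torch.randn(total_seq_len, heads, head_dim, dtype=torch.bfloat16, device=device)
    k = torch.randn(total_seq_len, heads, head_dim, dtype=torch.bfloat16, device=device)
    v = torch.randn(total_seq_len, heads, head_dim, dtype=torch.bfloat16, device=device)
    out = jagged_attention_wrapper(q, k, v, seq_offsets, None, max_seq_len=8)
    print(out.shape)  # torch.Size([8, 4, 32])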

test/test_examples.expected

Lines changed: 104 additions & 0 deletions
@@ -910,6 +910,110 @@ def jagged_dense_add_2d(x_data: torch.Tensor, x_offsets: torch.Tensor, y: torch.
     _launcher(_helion_jagged_dense_add_2d, (triton.cdiv(num_rows, _BLOCK_SIZE_0),), x_offsets, x_data, y, out, y.size(1), out.stride(0), out.stride(1), x_data.stride(0), x_offsets.stride(0), y.stride(0), y.stride(1), num_rows, _BLOCK_SIZE_0, _BLOCK_SIZE_1, _BLOCK_SIZE_2, num_warps=4, num_stages=3)
     return out
 
+--- assertExpectedJournal(TestExamples.test_jagged_hstu_attn)
+from __future__ import annotations
+
+import torch
+import triton
+import triton.language as tl
+from helion.runtime import default_launcher as _default_launcher
+
+@triton.jit
+def _helion__helion_jagged_attention_kernel(seq_offsets, q, k, v, out, k_stride_0, k_stride_1, k_stride_2, out_stride_0, out_stride_1, out_stride_2, q_stride_0, q_stride_1, q_stride_2, seq_offsets_stride_0, v_stride_0, v_stride_1, v_stride_2, max_seq_len, alpha, scale, _BLOCK_SIZE_2: tl.constexpr, _RDIM_SIZE_3: tl.constexpr, _BLOCK_SIZE_4: tl.constexpr):
+    num_blocks_0 = 4
+    num_blocks_1 = 8
+    pid_0 = tl.program_id(0) % num_blocks_0
+    pid_1 = tl.program_id(0) // num_blocks_0 % num_blocks_1
+    pid_2 = tl.program_id(0) // (num_blocks_0 * num_blocks_1)
+    offset_0 = pid_0
+    offset_1 = pid_1
+    offset_2 = pid_2 * _BLOCK_SIZE_2
+    indices_2 = (offset_2 + tl.arange(0, _BLOCK_SIZE_2)).to(tl.int32)
+    mask_2 = indices_2 < max_seq_len
+    indices_5 = tl.arange(0, _RDIM_SIZE_3).to(tl.int32)
+    starts = tl.load(seq_offsets + offset_0 * seq_offsets_stride_0, None)
+    add = 1 + offset_0
+    ends = tl.load(seq_offsets + add * seq_offsets_stride_0, None)
+    v_0 = ends - starts
+    v_1 = v_0 > offset_2
+    if v_1:
+        v_0_copy = v_0
+        starts_copy = starts
+        v_0_copy_0 = v_0_copy
+        starts_copy_0 = starts_copy
+        v_2 = v_0_copy_0[None]
+        v_3 = v_2.to(tl.int32)
+        v_4 = indices_2 < v_3
+        v_5 = starts_copy_0[None]
+        v_6 = v_5.to(tl.int32)
+        v_7 = indices_2 + v_6
+        q_blk = tl.load(q + (v_7[:, None] * q_stride_0 + offset_1 * q_stride_1 + indices_5[None, :] * q_stride_2), mask_2[:, None], other=0)
+        acc = tl.full([_BLOCK_SIZE_2, 32], 0.0, tl.float32)
+        tile_end = tl.minimum(offset_2 + _BLOCK_SIZE_2, max_seq_len)
+        for offset_3 in tl.range(0, tile_end.to(tl.int32), _BLOCK_SIZE_4):
+            indices_3 = offset_3 + tl.arange(0, _BLOCK_SIZE_4).to(tl.int32)
+            mask_4 = indices_3 < tile_end
+            v_0_copy_0_copy = v_0_copy_0
+            starts_copy_0_copy = starts_copy_0
+            q_blk_copy = q_blk
+            v_4_copy = v_4
+            acc_copy = acc
+            v_0_copy_0_copy_0 = v_0_copy_0_copy
+            starts_copy_0_copy_0 = starts_copy_0_copy
+            q_blk_copy_0 = q_blk_copy
+            v_4_copy_0 = v_4_copy
+            acc_copy_0 = acc_copy
+            v_8 = v_0_copy_0_copy_0[None]
+            v_9 = v_8.to(tl.int32)
+            v_10 = indices_3 < v_9
+            v_11 = starts_copy_0_copy_0[None]
+            v_12 = v_11.to(tl.int32)
+            v_13 = indices_3 + v_12
+            k_blk = tl.load(k + (v_13[:, None] * k_stride_0 + offset_1 * k_stride_1 + indices_5[None, :] * k_stride_2), mask_4[:, None], other=0)
+            v_14 = starts_copy_0_copy_0[None]
+            v_15 = v_14.to(tl.int32)
+            v_16 = indices_3 + v_15
+            v_blk = tl.load(v + (v_16[:, None] * v_stride_0 + offset_1 * v_stride_1 + indices_5[None, :] * v_stride_2), mask_4[:, None], other=0)
+            permute = tl.permute(k_blk, [1, 0])
+            mm = tl.dot(q_blk_copy_0, permute, input_precision='tf32')
+            v_17 = alpha.to(tl.bfloat16)
+            v_18 = mm * v_17
+            v_19 = v_18.to(tl.float32)
+            v_20 = tl.sigmoid(v_19)
+            v_21 = v_19 * v_20
+            v_22 = v_21.to(tl.bfloat16)
+            v_23 = scale.to(tl.bfloat16)
+            v_24 = v_22 * v_23
+            unsqueeze = indices_2[:, None]
+            unsqueeze_1 = indices_3[None, :]
+            v_25 = unsqueeze > unsqueeze_1
+            subscript = v_4_copy_0[:, None]
+            v_26 = v_25 & subscript
+            subscript_1 = v_10[None, :]
+            v_27 = v_26 & subscript_1
+            v_28 = tl.full([], 0.0, tl.bfloat16)
+            v_29 = v_28[None, None]
+            v_30 = tl.where(v_27, v_24, v_29)
+            _mask_to_2 = tl.where(mask_2[:, None] & mask_4[None, :], v_30, 0)
+            mm_1 = tl.dot(_mask_to_2, v_blk, input_precision='tf32')
+            v_31 = mm_1.to(tl.float32)
+            acc = acc_copy_0 + v_31
+        v_33 = starts_copy_0[None]
+        v_34 = v_33.to(tl.int32)
+        v_35 = indices_2 + v_34
+        v_36 = acc.to(tl.bfloat16)
+        tl.store(out + (v_35[:, None] * out_stride_0 + offset_1 * out_stride_1 + indices_5[None, :] * out_stride_2), v_36, mask_2[:, None])
+
+def _helion_jagged_attention_kernel(max_seq_len: int, alpha: float, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_offsets: torch.Tensor, *, _launcher=_default_launcher):
+    """Helion implementation of HSTU jagged attention"""
+    scale = 1.0 / max_seq_len
+    out = torch.zeros_like(v)
+    _BLOCK_SIZE_2 = 16
+    _RDIM_SIZE_3 = 32
+    _BLOCK_SIZE_4 = 16
+    _launcher(_helion__helion_jagged_attention_kernel, (4 * q.size(1) * triton.cdiv(max_seq_len, _BLOCK_SIZE_2),), seq_offsets, q, k, v, out, k.stride(0), k.stride(1), k.stride(2), out.stride(0), out.stride(1), out.stride(2), q.stride(0), q.stride(1), q.stride(2), seq_offsets.stride(0), v.stride(0), v.stride(1), v.stride(2), max_seq_len, alpha, scale, _BLOCK_SIZE_2, _RDIM_SIZE_3, _BLOCK_SIZE_4, num_warps=4, num_stages=3)
+    return out
+
 --- assertExpectedJournal(TestExamples.test_jagged_mean)
 from __future__ import annotations
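
The generated launcher flattens the (batch, head, sequence-tile) grid into a single program id, with grid size 4 * q.size(1) * triton.cdiv(max_seq_len, _BLOCK_SIZE_2), and the kernel recovers the three coordinates by modulo and floor division (num_blocks_0 = 4 batches and num_blocks_1 = 8 heads in this journal). A small pure-Python sketch of that decomposition, not part of the commit, using an illustrative max_seq_len:

    # pid -> (batch, head, seq_tile) decomposition mirroring the generated kernel above.
    num_blocks_0, num_blocks_1 = 4, 8  # specialized batch and head counts in this journal
    max_seq_len, block_size_2 = 1024, 16  # max_seq_len is illustrative; _BLOCK_SIZE_2 = 16
    num_seq_tiles = (max_seq_len + block_size_2 - 1) // block_size_2  # triton.cdiv
    grid = num_blocks_0 * num_blocks_1 * num_seq_tiles

    for pid in range(grid):
        batch = pid % num_blocks_0
        head = pid // num_blocks_0 % num_blocks_1
        seq_tile = pid // (num_blocks_0 * num_blocks_1)
        assert 0 <= batch < num_blocks_0
        assert 0 <= head < num_blocks_1
        assert 0 <= seq_tile < num_seq_tiles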
