
Commit 47485d6

jagged hstu example
ghstack-source-id: b085551
Pull Request resolved: #527
1 parent 3c0348a commit 47485d6

3 files changed: +297 -2 lines changed

examples/jagged_hstu_attn.py

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
"""
Simplified Jagged HSTU Attention Forward Example
================================================

This example demonstrates a simplified version of jagged HSTU attention using Helion.
"""

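# A sketch of the computation implemented by the reference and the Helion kernel
# below: for each jagged sequence i (rows seq_offsets[i]:seq_offsets[i + 1] of the
# packed [total_seq_len, heads, head_dim] q/k/v tensors), the softmax of standard
# attention is replaced by a pointwise SiLU,
#
#     out_i = tril(silu(q_i @ k_i.T * alpha) / N) @ v_i
#
# with alpha = 1 / sqrt(head_dim) and N the maximum sequence length.
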
# %%
# Imports
# -------
from __future__ import annotations

from typing import Any

import torch

import helion
from helion._testing import run_example
import helion.language as hl

try:
    from generative_recommenders.ops.triton.triton_hstu_attention import (
        triton_hstu_mha,  # pyright: ignore[reportMissingImports]
    )

    HAS_HAMMER = True
except ImportError:
    HAS_HAMMER = False


def generate_inputs() -> dict[str, Any]:
    """Generate small inputs for HSTU attention for easier verification and testing"""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dtype = torch.bfloat16

    batch_size = 1024
    max_seq_len = 1024  # N parameter
    heads = 4
    head_dim = 128

    # Generate random sequence lengths
    min_seq_len = max_seq_len // 2
    seq_lengths = torch.randint(
        min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device
    )
    seq_offsets = torch.cat(
        [
            torch.tensor([0], dtype=torch.int32, device=device),
            torch.cumsum(seq_lengths, dim=0),
        ]
    )
    total_seq_len = int(seq_offsets[-1].item())

    # Generate tensors with ragged sequence length
    # q, k, v: [total_seq_len, heads, head_dim]
    q = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    k = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )
    v = torch.randn(
        (total_seq_len, heads, head_dim),
        dtype=dtype,
        device=device,
        requires_grad=True,
    )

    # Parameters
    alpha = 1.0 / (head_dim**0.5)  # Scaling factor
    invalid_attn_mask_type = "lower_triangular"

    # Optional parameters (set to None for simplicity)
    num_targets = None
    attn_scale = None
    attn_bias = None
    seq2_offsets = None

    # Integer parameters
    max_attn_len = 0
    contextual_seq_len = 0
    sort_by_length = False
    full_attn_size = 0

    return {
        "N": max_seq_len,
        "alpha": alpha,
        "q": q,
        "k": k,
        "v": v,
        "seq_offsets": seq_offsets,
        "invalid_attn_mask_type": invalid_attn_mask_type,
        "num_targets": num_targets,
        "attn_scale": attn_scale,
        "attn_bias": attn_bias,
        "seq2_offsets": seq2_offsets,
        "max_attn_len": max_attn_len,
        "contextual_seq_len": contextual_seq_len,
        "sort_by_length": sort_by_length,
        "full_attn_size": full_attn_size,
    }


def reference_jagged_hstu_kernel_pytorch(inputs: dict[str, Any]) -> torch.Tensor:
    """Simple PyTorch implementation of HSTU ragged attention using direct tensor slicing"""
    q = inputs["q"]
    k = inputs["k"]
    v = inputs["v"]
    seq_offsets = inputs["seq_offsets"]
    alpha = inputs["alpha"]
    N = inputs["N"]

    # Initialize output
    output = torch.zeros_like(v)

    # Scale factor
    scale = 1.0 / N

    # Compute per-batch sequence lengths
    seq_lens = seq_offsets[1:] - seq_offsets[:-1]

    q_split = torch.split(q, seq_lens.tolist(), dim=0)
    k_split = torch.split(k, seq_lens.tolist(), dim=0)
    v_split = torch.split(v, seq_lens.tolist(), dim=0)

    # Process each sequence in the batch using direct tensor slicing
    for i, (q_batch, k_batch, v_batch) in enumerate(
        zip(q_split, k_split, v_split, strict=False)
    ):
        q_batch = q_batch.transpose(0, 1)  # [heads, seq_len, head_dim]
        k_batch = k_batch.permute(1, 2, 0)  # [heads, head_dim, seq_len]
        v_batch = v_batch.transpose(0, 1)  # [heads, seq_len, head_dim]

        # Compute attention scores using batch matrix multiplication
        scores = torch.bmm(q_batch, k_batch) * alpha

        # Apply SiLU activation
        scores = (scores / (1.0 + torch.exp(-scores))) * scale

        # Apply invalid mask
        if inputs["invalid_attn_mask_type"] == "lower_triangular":
            invalid_mask = torch.tril(
                torch.ones_like(scores, dtype=torch.bool), diagonal=0
            )
            scores = torch.where(invalid_mask, scores, torch.zeros_like(scores))

        # Compute and store output
        output_batch = torch.bmm(scores, v_batch)
        output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1)

    return output

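# The Helion kernel below tiles over (batch, head, query block): each query block
# loads its rows once, then iterates over key/value blocks of the same sequence.
# With the lower-triangular mask the inner loop only needs to cover key positions
# up to tile_q.end, and positions past the sequence length are zeroed via
# mask_q / mask_kv.
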
@helion.kernel()
def _helion_ragged_attention_kernel(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    alpha: float,
    invalid_mask_type: str,
    max_seq_len_tensor: torch.Tensor,
) -> torch.Tensor:
    max_seq_len = max_seq_len_tensor.numel()
    scale = 1.0 / max_seq_len

    num_heads = hl.specialize(q.size(1))
    num_batches = hl.specialize(seq_offsets.size(0) - 1)
    dimV = hl.specialize(v.size(2))

    out = torch.zeros_like(v)

    # --- Tile over batch, head, sequence ---
    for tile_b, tile_h, tile_q in hl.tile(
        [num_batches, num_heads, max_seq_len], block_size=[1, 1, None]
    ):
        starts = seq_offsets[tile_b.begin]
        ends = seq_offsets[tile_b.begin + 1]
        seq_len = ends - starts

        if tile_q.begin < seq_len:
            mask_q = tile_q.index < seq_len
            q_blk = q[tile_q.index + starts, tile_h.begin, :]
            acc = hl.zeros([tile_q, dimV], dtype=torch.float32)

            if invalid_mask_type == "lower_triangular":
                low = 0
                high = tile_q.end
            else:
                low = 0
                high = seq_len

            for tile_kv in hl.tile(low, high, block_size=None):
                mask_kv = tile_kv.index < seq_len
                k_blk = k[tile_kv.index + starts, tile_h.begin, :]
                v_blk = v[tile_kv.index + starts, tile_h.begin, :]

                scores = (
                    torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha)
                    * scale
                )

                if invalid_mask_type == "lower_triangular":
                    scores = torch.where(
                        (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0))
                        & mask_q[:, None]
                        & mask_kv[None, :],
                        scores,
                        0.0,
                    )

                acc += torch.matmul(scores.to(v.dtype), v_blk)

            # Store result
            out[tile_q.index + starts, tile_h.begin, :] = acc

    return out

def helion_ragged_attention_function(inputs: dict[str, Any]) -> torch.Tensor:
    """
    Wrapper function for the Helion ragged attention kernel.
    """

    return _helion_ragged_attention_kernel(
        q=inputs["q"],
        k=inputs["k"],
        v=inputs["v"],
        seq_offsets=inputs["seq_offsets"],
        alpha=inputs["alpha"],
        invalid_mask_type=inputs["invalid_attn_mask_type"],
        max_seq_len_tensor=torch.empty(inputs["N"], device=inputs["q"].device),
    )
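

# Note: the wrapper above passes a placeholder tensor of length N so that the
# kernel can recover the maximum sequence length with max_seq_len_tensor.numel().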


def tritonbench_hstu_attention_function(inputs: dict[str, Any]) -> torch.Tensor:
    """
    Wrapper function for the tritonbench HSTU attention implementation.

    Args:
        inputs: Dictionary containing all the input parameters

    Returns:
        Output tensor from tritonbench HSTU attention
    """
    if not HAS_HAMMER:
        # Return a dummy tensor with the same shape as expected output
        return torch.zeros_like(inputs["v"])

    return triton_hstu_mha(  # pyright: ignore[reportCallIssue,reportPossiblyUnboundVariable]
        N=inputs["N"],
        alpha=inputs["alpha"],
        q=inputs["q"],
        k=inputs["k"],
        v=inputs["v"],
        seq_offsets=inputs["seq_offsets"],
        num_targets=inputs["num_targets"],
        max_attn_len=inputs["max_attn_len"],
        contextual_seq_len=inputs["contextual_seq_len"],
        sort_by_length=inputs["sort_by_length"],
    )


def main() -> None:
    """
    Main entry point for testing the simplified jagged HSTU attention kernel.
    """
    inputs = generate_inputs()

    baselines = {
        "torch": lambda inputs: reference_jagged_hstu_kernel_pytorch(inputs),
    }
    if HAS_HAMMER:
        baselines["tritonbench"] = lambda inputs: tritonbench_hstu_attention_function(
            inputs
        )

    run_example(
        lambda inputs: helion_ragged_attention_function(inputs), baselines, (inputs,)
    )


if __name__ == "__main__":
    main()

helion/_testing.py

Lines changed: 3 additions & 1 deletion
@@ -589,7 +589,9 @@ def setUp(self) -> None:
         super().setUp()
         self._test_stack = contextlib.ExitStack()

-        from torch._inductor.utils import fresh_cache
+        from torch._inductor.utils import (
+            fresh_cache,  # pyright: ignore[reportAttributeAccessIssue]
+        )

         self._test_stack.enter_context(fresh_cache())

helion/autotuner/base_cache.py

Lines changed: 3 additions & 1 deletion
@@ -65,7 +65,9 @@ def torch_key_wrapper() -> str:
 @functools.cache
 def triton_key_wrapper() -> str:
-    from torch._inductor.runtime.triton_compat import triton_key
+    from torch._inductor.runtime.triton_compat import (
+        triton_key,  # pyright: ignore[reportAttributeAccessIssue]
+    )

     return triton_key()

0 commit comments
