
Commit 5a92b75

[WIP] ragged_attention
1 parent 6765038 commit 5a92b75

File tree

  benchmarks/run.py
  examples/ragged_attention.py

2 files changed: +138 -0


benchmarks/run.py

Lines changed: 29 additions & 0 deletions
@@ -84,6 +84,11 @@
         "examples.fused_linear_cross_entropy",
         "fused_linear_cross_entropy",
     ),
+    "ragged_attention": (
+        "tritonbench.operators.ragged_attention.operator",
+        "examples.ragged_attention",
+        "ragged_attention_tritonbench",
+    ),
 }


@@ -306,6 +311,30 @@ def _inner() -> Callable[..., Any]:
         if isinstance(attr, Kernel):
             attr.settings.force_autotune = True

+        # Handle special case for ragged_attention which needs additional parameters
+        if kernel_name == "ragged_attention" and len(args) == 6:
+            # Extract the 6 arguments from tritonbench
+            q, k, v, seq_offsets, num_targets, max_seq_len = args
+
+            # Convert None num_targets to empty tensor
+            if num_targets is None:
+                num_targets = torch.empty(0, dtype=torch.int32, device=q.device)
+
+            # Call with the same parameter order and defaults as triton_hstu_mha
+            # Values taken from the operator.py defaults
+            return kernel_func(
+                max_seq_len,
+                1.0 / q.size(-1),  # alpha = 1.0 / attn_dim (from operator.py line 74)
+                q,
+                k,
+                v,
+                seq_offsets,
+                num_targets,
+                0,  # max_attn_len (default from operator.py)
+                0,  # contextual_seq_len (default from operator.py)
+                True,  # sort_by_length (default from triton_hstu_mha call)
+            )
+
         return kernel_func(*args)

     return _inner
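For orientation, the call below is a minimal sketch of how those six positional arguments from tritonbench line up with the wrapper's parameter order. The sequence lengths, head count, dtype, and the assumption of a CUDA device with the repository root importable are illustrative, not taken from the operator.

    # Illustrative only: assumes the repository root is on sys.path and a CUDA
    # device is available; shapes and dtypes are made up for the sketch.
    import torch

    from examples.ragged_attention import ragged_attention_tritonbench

    lengths = torch.tensor([5, 3, 8], device="cuda")  # 3 ragged sequences
    seq_offsets = torch.zeros(lengths.numel() + 1, dtype=torch.int64, device="cuda")
    seq_offsets[1:] = torch.cumsum(lengths, dim=0)  # offsets [0, 5, 8, 16]

    total_tokens, num_heads, head_dim = int(seq_offsets[-1]), 4, 64
    q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = ragged_attention_tritonbench(
        int(lengths.max()),  # max_seq_len
        1.0 / head_dim,      # alpha, mirroring the adapter above
        q, k, v, seq_offsets,
        None,                # num_targets (ignored by the Helion kernel)
        0,                   # max_attn_len
        0,                   # contextual_seq_len
        True,                # sort_by_length
    )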

examples/ragged_attention.py

Lines changed: 109 additions & 0 deletions
"""
Ragged attention implementation in Helion DSL.

This implements HSTU (Hierarchical Sequential Transduction Unit) attention
which handles variable-length sequences using seq_offsets.
"""

import torch

import helion
import helion.language as hl


@helion.kernel(
    config=helion.Config(
        block_sizes=[32, 128],
        num_warps=4,
        num_stages=2,
    ),
    static_shapes=True,
)
def ragged_attention(
    q: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    k: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    v: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    seq_offsets: torch.Tensor,  # [num_sequences + 1] - cumulative token positions
    alpha: float,
    max_seq_len: int,
) -> torch.Tensor:
    """Ragged attention using SiLU activation with proper sequence masking."""
    total_tokens = q.size(0)
    num_heads = q.size(1)
    head_dim = hl.specialize(q.size(2))
    num_sequences = seq_offsets.size(0) - 1

    out = torch.zeros_like(q)

    # Precompute inverse to avoid fp64 promotion
    inv_max_seq_len = 1.0 / float(max_seq_len)

    # Process each sequence
    for seq_idx in hl.grid(num_sequences):
        # Get sequence boundaries
        seq_start = seq_offsets[seq_idx]
        seq_end = seq_offsets[seq_idx + 1]
        seq_length = seq_end - seq_start

        # Skip empty sequences
        if seq_length > 0:
            # Process each head independently
            for head_idx in hl.grid(num_heads):
                # Tile over query positions in this sequence
                for tile_q in hl.tile(seq_start, seq_end):
                    # Initialize accumulator for this tile
                    acc = hl.zeros([tile_q, head_dim], dtype=torch.float32)

                    # Attend to all key positions in this sequence
                    for tile_k in hl.tile(seq_start, seq_end):
                        # Load Q and K chunks for this head
                        q_chunk = q[tile_q, head_idx, :]  # [tile_q, head_dim]
                        k_chunk = k[tile_k, head_idx, :]  # [tile_k, head_dim]

                        # Compute attention scores: Q @ K^T
                        scores = torch.matmul(q_chunk, k_chunk.T)  # [tile_q, tile_k]
                        scores_scaled = scores * alpha

                        # Apply SiLU activation: x / (1 + exp(-x))
                        exp_neg = torch.exp(-scores_scaled)
                        silu_scores = scores_scaled / (1.0 + exp_neg)

                        # Scale by 1/max_seq_len
                        silu_normalized = silu_scores * inv_max_seq_len

                        # Load V chunk
                        v_chunk = v[tile_k, head_idx, :]  # [tile_k, head_dim]

                        # Accumulate weighted values
                        update = torch.matmul(
                            silu_normalized.to(torch.float32), v_chunk.to(torch.float32)
                        )
                        acc = acc + update

                    # Store results back
                    out[tile_q, head_idx, :] = acc.to(out.dtype)

    return out


def ragged_attention_tritonbench(
    max_seq_len: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_attn_len: int,
    contextual_seq_len: int,
    sort_by_length: bool,
) -> torch.Tensor:
    """Wrapper to match tritonbench interface."""
    return ragged_attention(q, k, v, seq_offsets, alpha, max_seq_len)


# For tritonbench integration
TRITONBENCH_ARGS = {
    "batch_size": 64,  # Reduced for better performance
    "heads": 4,
    "min_seq_len_log2": 8,  # 2^8 = 256
    "max_seq_len_log2": 8,  # 2^8 = 256
}
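As a spot-check aid only (not part of this commit), here is a sketch of an eager-PyTorch reference for the same per-sequence math: scores scaled by alpha, SiLU applied, normalized by 1/max_seq_len, then multiplied by V. The function name and the tolerances in the commented check are assumptions.

    # Hypothetical reference, useful for comparing against the Helion kernel on
    # small inputs; loops over sequences in plain PyTorch.
    import torch
    import torch.nn.functional as F


    def ragged_attention_reference(
        q: torch.Tensor,            # [total_tokens, num_heads, head_dim]
        k: torch.Tensor,
        v: torch.Tensor,
        seq_offsets: torch.Tensor,  # [num_sequences + 1]
        alpha: float,
        max_seq_len: int,
    ) -> torch.Tensor:
        out = torch.zeros_like(q)
        for i in range(seq_offsets.numel() - 1):
            start, end = int(seq_offsets[i]), int(seq_offsets[i + 1])
            if end <= start:
                continue  # skip empty sequences, matching the kernel
            # [seq_len, num_heads, head_dim] -> [num_heads, seq_len, head_dim]
            qs = q[start:end].transpose(0, 1).float()
            ks = k[start:end].transpose(0, 1).float()
            vs = v[start:end].transpose(0, 1).float()
            scores = torch.matmul(qs, ks.transpose(-1, -2)) * alpha
            weights = F.silu(scores) / max_seq_len
            out[start:end] = torch.matmul(weights, vs).transpose(0, 1).to(out.dtype)
        return out


    # Example check with relaxed tolerances for fp16 accumulation differences:
    # torch.testing.assert_close(
    #     ragged_attention(q, k, v, seq_offsets, alpha, max_seq_len),
    #     ragged_attention_reference(q, k, v, seq_offsets, alpha, max_seq_len),
    #     rtol=1e-2, atol=1e-2,
    # )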
