1+ """
2+ Ragged attention implementation in Helion DSL.
3+
4+ This implements HSTU (Hierarchical Sequential Transduction Unit) attention
5+ which handles variable-length sequences using seq_offsets.
6+ """

import torch

import helion
import helion.language as hl


@helion.kernel(
    config=helion.Config(
        block_sizes=[32, 128],
        num_warps=4,
        num_stages=2,
    ),
    static_shapes=True,
)
def ragged_attention(
    q: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    k: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    v: torch.Tensor,  # [total_tokens, num_heads, head_dim]
    seq_offsets: torch.Tensor,  # [num_sequences + 1] - cumulative token positions
    alpha: float,
    max_seq_len: int,
) -> torch.Tensor:
    """Ragged attention using SiLU activation with proper sequence masking."""
    total_tokens = q.size(0)
    num_heads = q.size(1)
    head_dim = hl.specialize(q.size(2))
    num_sequences = seq_offsets.size(0) - 1

    out = torch.zeros_like(q)

    # Precompute inverse to avoid fp64 promotion
    inv_max_seq_len = 1.0 / float(max_seq_len)

    # Process each sequence
    for seq_idx in hl.grid(num_sequences):
        # Get sequence boundaries
        seq_start = seq_offsets[seq_idx]
        seq_end = seq_offsets[seq_idx + 1]
        seq_length = seq_end - seq_start

        # Skip empty sequences
        if seq_length > 0:
            # Process each head independently
            for head_idx in hl.grid(num_heads):
                # Tile over query positions in this sequence
                for tile_q in hl.tile(seq_start, seq_end):
                    # Initialize accumulator for this tile
                    acc = hl.zeros([tile_q, head_dim], dtype=torch.float32)

                    # Attend to all key positions in this sequence
                    for tile_k in hl.tile(seq_start, seq_end):
                        # Load Q and K chunks for this head
                        q_chunk = q[tile_q, head_idx, :]  # [tile_q, head_dim]
                        k_chunk = k[tile_k, head_idx, :]  # [tile_k, head_dim]

                        # Compute attention scores: Q @ K^T
                        scores = torch.matmul(q_chunk, k_chunk.T)  # [tile_q, tile_k]
                        scores_scaled = scores * alpha

                        # Apply SiLU activation: x / (1 + exp(-x))
                        exp_neg = torch.exp(-scores_scaled)
                        silu_scores = scores_scaled / (1.0 + exp_neg)

                        # Scale by 1/max_seq_len
                        silu_normalized = silu_scores * inv_max_seq_len

                        # Load V chunk
                        v_chunk = v[tile_k, head_idx, :]  # [tile_k, head_dim]

                        # Accumulate weighted values
                        update = torch.matmul(silu_normalized.to(torch.float32), v_chunk.to(torch.float32))
                        acc = acc + update

                    # Store results back
                    out[tile_q, head_idx, :] = acc.to(out.dtype)

    return out
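

# A plain-PyTorch reference of the same computation, added here as a sketch for
# checking the kernel's output. The function name and einsum formulation are not
# part of the original interface; the weights are SiLU(alpha * q @ k^T) / max_seq_len,
# with no softmax, mirroring the kernel above.
def ragged_attention_reference(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    alpha: float,
    max_seq_len: int,
) -> torch.Tensor:
    out = torch.zeros_like(q)
    inv_max_seq_len = 1.0 / float(max_seq_len)
    for i in range(seq_offsets.numel() - 1):
        start = int(seq_offsets[i])
        end = int(seq_offsets[i + 1])
        if end <= start:
            continue
        qi = q[start:end].float()  # [seq_len, num_heads, head_dim]
        ki = k[start:end].float()
        vi = v[start:end].float()
        # Per-head scores: [num_heads, seq_len, seq_len]
        scores = torch.einsum("qhd,khd->hqk", qi, ki) * alpha
        weights = torch.nn.functional.silu(scores) * inv_max_seq_len
        out[start:end] = torch.einsum("hqk,khd->qhd", weights, vi).to(out.dtype)
    return out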


def ragged_attention_tritonbench(
    max_seq_len: int,
    alpha: float,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    seq_offsets: torch.Tensor,
    num_targets: torch.Tensor | None,
    max_attn_len: int,
    contextual_seq_len: int,
    sort_by_length: bool,
) -> torch.Tensor:
    """Wrapper to match the tritonbench interface.

    num_targets, max_attn_len, contextual_seq_len, and sort_by_length are
    accepted for signature compatibility but are not used by this kernel.
    """
    return ragged_attention(q, k, v, seq_offsets, alpha, max_seq_len)


# For tritonbench integration
TRITONBENCH_ARGS = {
    "batch_size": 64,  # Reduced for better performance
    "heads": 4,
    "min_seq_len_log2": 8,  # 2^8 = 256
    "max_seq_len_log2": 8,  # 2^8 = 256
}
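

# Minimal usage sketch (illustrative only; the shapes, dtype, and CUDA device
# below are assumptions, not requirements of the benchmark setup above). Expects
# a GPU with Helion/Triton available.
if __name__ == "__main__":
    if torch.cuda.is_available():
        num_heads, head_dim = 4, 64
        # Three variable-length sequences packed into a single tensor.
        lengths = torch.tensor([128, 64, 200], dtype=torch.int64, device="cuda")
        seq_offsets = torch.cat(
            [torch.zeros(1, dtype=torch.int64, device="cuda"), lengths.cumsum(0)]
        )
        total_tokens = int(seq_offsets[-1])
        q = torch.randn(total_tokens, num_heads, head_dim, device="cuda", dtype=torch.float16)
        k = torch.randn_like(q)
        v = torch.randn_like(q)
        # Scale and max_seq_len chosen arbitrarily for this smoke test.
        out = ragged_attention(q, k, v, seq_offsets, 1.0 / head_dim**0.5, 256)
        print(out.shape)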