|
| 1 | +""" |
| 2 | +Simplified Jagged HSTU Attention Forward Example |
| 3 | +=============================================== |
| 4 | +
|
| 5 | +This example demonstrates a simplified version of jagged HSTU attention using Helion. |
| 6 | +""" |
| 7 | + |
| 8 | +# %% |
| 9 | +# Imports |
| 10 | +# ------- |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +import torch |
| 14 | + |
| 15 | +import helion |
| 16 | +from helion._testing import run_example |
| 17 | +import helion.language as hl |
| 18 | + |
| 19 | +try: |
| 20 | + from generative_recommenders.ops.triton.triton_hstu_attention import ( # pyright: ignore[reportMissingImports] |
| 21 | + triton_hstu_mha, |
| 22 | + ) |
| 23 | + |
| 24 | + HAS_HAMMER = True |
| 25 | +except ImportError: |
| 26 | + HAS_HAMMER = False |
| 27 | + |
| 28 | + |
| 29 | +def reference_jagged_hstu_kernel_pytorch( |
| 30 | + q: torch.Tensor, |
| 31 | + k: torch.Tensor, |
| 32 | + v: torch.Tensor, |
| 33 | + seq_offsets: torch.Tensor, |
| 34 | + num_targets: torch.Tensor | None, |
| 35 | + max_seq_len: int, |
| 36 | +) -> torch.Tensor: |
| 37 | + """Simple PyTorch implementation of HSTU jagged attention""" |
| 38 | + # Initialize output |
| 39 | + output = torch.zeros_like(v) |
| 40 | + |
| 41 | + # Scale factor |
| 42 | + scale = 1.0 / max_seq_len |
| 43 | + alpha = 1.0 / v.size(2) ** 2 |
| 44 | + |
| 45 | + # Compute per-batch sequence lengths |
| 46 | + seq_lens = seq_offsets[1:] - seq_offsets[:-1] |
| 47 | + |
| 48 | + q_split = torch.split(q, seq_lens.tolist(), dim=0) |
| 49 | + k_split = torch.split(k, seq_lens.tolist(), dim=0) |
| 50 | + v_split = torch.split(v, seq_lens.tolist(), dim=0) |
| 51 | + |
| 52 | + # Get the batches |
| 53 | + for i, (q_batch, k_batch, v_batch) in enumerate( |
| 54 | + zip(q_split, k_split, v_split, strict=False) |
| 55 | + ): |
| 56 | + q_batch = q_batch.transpose(0, 1) # [heads, seq_len, head_dim] |
| 57 | + k_batch = k_batch.permute(1, 2, 0) # [heads, head_dim, seq_len] |
| 58 | + v_batch = v_batch.transpose(0, 1) # [heads, seq_len, head_dim] |
| 59 | + |
| 60 | + # Compute attention scores using batch matrix multiplication |
| 61 | + scores = torch.bmm(q_batch, k_batch) * alpha |
| 62 | + |
| 63 | + # Apply SiLU activation |
| 64 | + scores = (scores / (1.0 + torch.exp(-scores))) * scale |
| 65 | + |
| 66 | + # Apply lower triangular mask (causal attention) |
| 67 | + invalid_mask = torch.tril(torch.ones_like(scores, dtype=torch.bool), diagonal=0) |
| 68 | + scores = torch.where(invalid_mask, scores, torch.zeros_like(scores)) |
| 69 | + |
| 70 | + # Compute and store output |
| 71 | + output_batch = torch.bmm(scores, v_batch) |
| 72 | + output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1) |
| 73 | + |
| 74 | + return output |
| 75 | + |
| 76 | + |
| 77 | +@helion.kernel() |
| 78 | +def _helion_jagged_attention_kernel( |
| 79 | + max_seq_len: int, |
| 80 | + alpha: float, |
| 81 | + q: torch.Tensor, |
| 82 | + k: torch.Tensor, |
| 83 | + v: torch.Tensor, |
| 84 | + seq_offsets: torch.Tensor, |
| 85 | +) -> torch.Tensor: |
| 86 | + """Helion implementation of HSTU jagged attention""" |
| 87 | + scale = 1.0 / max_seq_len |
| 88 | + num_heads = hl.specialize(q.size(1)) |
| 89 | + num_batches = hl.specialize(seq_offsets.size(0) - 1) |
| 90 | + dimV = hl.specialize(v.size(2)) |
| 91 | + |
| 92 | + out = torch.zeros_like(v) |
| 93 | + |
| 94 | + # Tile over batch, head, sequence |
| 95 | + for tile_b, tile_h, tile_q in hl.tile( |
| 96 | + [num_batches, num_heads, max_seq_len], block_size=[1, 1, None] |
| 97 | + ): |
| 98 | + starts = seq_offsets[tile_b.begin] |
| 99 | + ends = seq_offsets[tile_b.begin + 1] |
| 100 | + seq_len = ends - starts |
| 101 | + |
| 102 | + if tile_q.begin < seq_len: |
| 103 | + mask_q = tile_q.index < seq_len |
| 104 | + q_blk = q[tile_q.index + starts, tile_h.begin, :] |
| 105 | + acc = hl.zeros([tile_q, dimV], dtype=torch.float32) |
| 106 | + |
| 107 | + # Causal attention: only attend to previous tokens |
| 108 | + for tile_kv in hl.tile(0, tile_q.end, block_size=None): |
| 109 | + mask_kv = tile_kv.index < seq_len |
| 110 | + k_blk = k[tile_kv.index + starts, tile_h.begin, :] |
| 111 | + v_blk = v[tile_kv.index + starts, tile_h.begin, :] |
| 112 | + |
| 113 | + # Compute attention scores with SiLU activation |
| 114 | + scores = ( |
| 115 | + torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha) |
| 116 | + * scale |
| 117 | + ) |
| 118 | + |
| 119 | + # Apply causal mask: only attend to previous positions |
| 120 | + scores = torch.where( |
| 121 | + (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0)) |
| 122 | + & mask_q[:, None] |
| 123 | + & mask_kv[None, :], |
| 124 | + scores, |
| 125 | + 0.0, |
| 126 | + ) |
| 127 | + |
| 128 | + acc += torch.matmul(scores.to(v.dtype), v_blk) |
| 129 | + |
| 130 | + # Store result |
| 131 | + out[tile_q.index + starts, tile_h.begin, :] = acc.to(out.dtype) |
| 132 | + |
| 133 | + return out |
| 134 | + |
| 135 | + |
| 136 | +def jagged_attention_wrapper( |
| 137 | + q: torch.Tensor, |
| 138 | + k: torch.Tensor, |
| 139 | + v: torch.Tensor, |
| 140 | + seq_offsets: torch.Tensor, |
| 141 | + num_targets: torch.Tensor | None, |
| 142 | + max_seq_len: int, |
| 143 | +) -> torch.Tensor: |
| 144 | + """Wrapper function for jagged attention kernel""" |
| 145 | + return _helion_jagged_attention_kernel( |
| 146 | + max_seq_len=max_seq_len, |
| 147 | + alpha=1.0 / v.size(2) ** 2, |
| 148 | + q=q, |
| 149 | + k=k, |
| 150 | + v=v, |
| 151 | + seq_offsets=seq_offsets, |
| 152 | + ) |
| 153 | + |
| 154 | + |
| 155 | +def test( |
| 156 | + batch_size: int, |
| 157 | + max_seq_len: int, |
| 158 | + heads: int, |
| 159 | + head_dim: int, |
| 160 | + dtype: torch.dtype = torch.bfloat16, |
| 161 | + device: torch.device | str = "cuda", |
| 162 | +) -> None: |
| 163 | + """ |
| 164 | + Test the jagged HSTU attention kernel implementation. |
| 165 | +
|
| 166 | + Args: |
| 167 | + batch_size: Number of sequences in the batch |
| 168 | + max_seq_len: Maximum sequence length |
| 169 | + heads: Number of attention heads |
| 170 | + head_dim: Dimension of each attention head |
| 171 | + dtype: Data type for the tensors |
| 172 | + device: Device to run the test on |
| 173 | + """ |
| 174 | + device = torch.device(device) |
| 175 | + |
| 176 | + # Generate random sequence lengths |
| 177 | + min_seq_len = max_seq_len // 2 |
| 178 | + seq_lengths = torch.randint( |
| 179 | + min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device |
| 180 | + ) |
| 181 | + seq_offsets = torch.cat( |
| 182 | + [ |
| 183 | + torch.tensor([0], dtype=torch.int32, device=device), |
| 184 | + torch.cumsum(seq_lengths, dim=0), |
| 185 | + ] |
| 186 | + ) |
| 187 | + total_seq_len = int(seq_offsets[-1].item()) |
| 188 | + |
| 189 | + # q, k, v: [total_seq_len, heads, head_dim] |
| 190 | + q = torch.randn( |
| 191 | + (total_seq_len, heads, head_dim), |
| 192 | + dtype=dtype, |
| 193 | + device=device, |
| 194 | + requires_grad=True, |
| 195 | + ) |
| 196 | + k = torch.randn( |
| 197 | + (total_seq_len, heads, head_dim), |
| 198 | + dtype=dtype, |
| 199 | + device=device, |
| 200 | + requires_grad=True, |
| 201 | + ) |
| 202 | + v = torch.randn( |
| 203 | + (total_seq_len, heads, head_dim), |
| 204 | + dtype=dtype, |
| 205 | + device=device, |
| 206 | + requires_grad=True, |
| 207 | + ) |
| 208 | + |
| 209 | + baselines = { |
| 210 | + "torch": reference_jagged_hstu_kernel_pytorch, |
| 211 | + } |
| 212 | + if HAS_HAMMER: |
| 213 | + |
| 214 | + def _triton_hstu_mha( |
| 215 | + q: torch.Tensor, |
| 216 | + k: torch.Tensor, |
| 217 | + v: torch.Tensor, |
| 218 | + seq_offsets: torch.Tensor, |
| 219 | + num_targets: torch.Tensor | None, |
| 220 | + max_seq_len: int, |
| 221 | + ) -> torch.Tensor: |
| 222 | + return triton_hstu_mha( # pyright: ignore[reportPossiblyUnboundVariable,reportCallIssue] |
| 223 | + max_seq_len, |
| 224 | + alpha=1.0 / v.size(2) ** 2, |
| 225 | + q=q, |
| 226 | + k=k, |
| 227 | + v=v, |
| 228 | + seq_offsets=seq_offsets, |
| 229 | + num_targets=num_targets, |
| 230 | + max_attn_len=0, |
| 231 | + contextual_seq_len=0, |
| 232 | + ) |
| 233 | + |
| 234 | + baselines["tritonbench"] = _triton_hstu_mha |
| 235 | + |
| 236 | + run_example( |
| 237 | + jagged_attention_wrapper, |
| 238 | + baselines, |
| 239 | + (q, k, v, seq_offsets, None, max_seq_len), |
| 240 | + ) |
| 241 | + |
| 242 | + |
| 243 | +def main() -> None: |
| 244 | + """ |
| 245 | + Main entry point for testing the simplified jagged HSTU attention kernel. |
| 246 | + """ |
| 247 | + test(batch_size=1024, max_seq_len=1024, heads=4, head_dim=128, dtype=torch.bfloat16) |
| 248 | + |
| 249 | + |
| 250 | +if __name__ == "__main__": |
| 251 | + main() |
0 commit comments