|
| 1 | +""" |
| 2 | +Simplified Jagged HSTU Attention Forward Example |
| 3 | +=============================================== |
| 4 | +
|
| 5 | +This example demonstrates a simplified version of jagged HSTU attention using Helion. |
| 6 | +""" |
| 7 | + |
| 8 | +# %% |
| 9 | +# Imports |
| 10 | +# ------- |
| 11 | +from __future__ import annotations |
| 12 | + |
| 13 | +from typing import Any |
| 14 | + |
| 15 | +import torch |
| 16 | + |
| 17 | +import helion |
| 18 | +from helion._testing import run_example |
| 19 | +import helion.language as hl |
| 20 | + |
| 21 | +try: |
| 22 | + from generative_recommenders.ops.triton.triton_hstu_attention import ( |
| 23 | + triton_hstu_mha, # pyright: ignore[reportMissingImports] |
| 24 | + ) |
| 25 | + |
| 26 | + HAS_HAMMER = True |
| 27 | +except ImportError: |
| 28 | + HAS_HAMMER = False |
| 29 | + |
| 30 | + |
| 31 | +def generate_inputs() -> dict[str, Any]: |
| 32 | + """Generate small inputs for HSTU attention for easier verification and testing""" |
| 33 | + |
| 34 | + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
| 35 | + dtype = torch.bfloat16 |
| 36 | + |
| 37 | + batch_size = 1024 |
| 38 | + max_seq_len = 1024 # N parameter |
| 39 | + heads = 4 |
| 40 | + head_dim = 128 |
| 41 | + |
| 42 | + # Generate random sequence lengths |
| 43 | + min_seq_len = max_seq_len // 2 |
| 44 | + seq_lengths = torch.randint( |
| 45 | + min_seq_len, max_seq_len + 1, (batch_size,), dtype=torch.int32, device=device |
| 46 | + ) |
| 47 | + seq_offsets = torch.cat( |
| 48 | + [ |
| 49 | + torch.tensor([0], dtype=torch.int32, device=device), |
| 50 | + torch.cumsum(seq_lengths, dim=0), |
| 51 | + ] |
| 52 | + ) |
| 53 | + total_seq_len = int(seq_offsets[-1].item()) |
| 54 | + |
| 55 | + # Generate tensors with ragged sequence length |
| 56 | + # q, k, v: [total_seq_len, heads, head_dim] |
| 57 | + q = torch.randn( |
| 58 | + (total_seq_len, heads, head_dim), |
| 59 | + dtype=dtype, |
| 60 | + device=device, |
| 61 | + requires_grad=True, |
| 62 | + ) |
| 63 | + k = torch.randn( |
| 64 | + (total_seq_len, heads, head_dim), |
| 65 | + dtype=dtype, |
| 66 | + device=device, |
| 67 | + requires_grad=True, |
| 68 | + ) |
| 69 | + v = torch.randn( |
| 70 | + (total_seq_len, heads, head_dim), |
| 71 | + dtype=dtype, |
| 72 | + device=device, |
| 73 | + requires_grad=True, |
| 74 | + ) |
| 75 | + |
| 76 | + # Parameters |
| 77 | + alpha = 1.0 / (head_dim**0.5) # Scaling factor |
| 78 | + invalid_attn_mask_type = "lower_triangular" |
| 79 | + |
| 80 | + # Optional parameters (set to None for simplicity) |
| 81 | + num_targets = None |
| 82 | + attn_scale = None |
| 83 | + attn_bias = None |
| 84 | + seq2_offsets = None |
| 85 | + |
| 86 | + # Integer parameters |
| 87 | + max_attn_len = 0 |
| 88 | + contextual_seq_len = 0 |
| 89 | + sort_by_length = False |
| 90 | + full_attn_size = 0 |
| 91 | + |
| 92 | + return { |
| 93 | + "N": max_seq_len, |
| 94 | + "alpha": alpha, |
| 95 | + "q": q, |
| 96 | + "k": k, |
| 97 | + "v": v, |
| 98 | + "seq_offsets": seq_offsets, |
| 99 | + "invalid_attn_mask_type": invalid_attn_mask_type, |
| 100 | + "num_targets": num_targets, |
| 101 | + "attn_scale": attn_scale, |
| 102 | + "attn_bias": attn_bias, |
| 103 | + "seq2_offsets": seq2_offsets, |
| 104 | + "max_attn_len": max_attn_len, |
| 105 | + "contextual_seq_len": contextual_seq_len, |
| 106 | + "sort_by_length": sort_by_length, |
| 107 | + "full_attn_size": full_attn_size, |
| 108 | + } |
| 109 | + |
| 110 | + |
| 111 | +def reference_jagged_hstu_kernel_pytorch(inputs: dict[str, Any]) -> torch.Tensor: |
| 112 | + """Simple PyTorch implementation of HSTU ragged attention using direct tensor slicing""" |
| 113 | + q = inputs["q"] |
| 114 | + k = inputs["k"] |
| 115 | + v = inputs["v"] |
| 116 | + seq_offsets = inputs["seq_offsets"] |
| 117 | + alpha = inputs["alpha"] |
| 118 | + N = inputs["N"] |
| 119 | + |
| 120 | + # Initialize output |
| 121 | + output = torch.zeros_like(v) |
| 122 | + |
| 123 | + # Scale factor |
| 124 | + scale = 1.0 / N |
| 125 | + |
| 126 | + # Compute per-batch sequence lengths |
| 127 | + seq_lens = seq_offsets[1:] - seq_offsets[:-1] |
| 128 | + |
| 129 | + q_split = torch.split(q, seq_lens.tolist(), dim=0) |
| 130 | + k_split = torch.split(k, seq_lens.tolist(), dim=0) |
| 131 | + v_split = torch.split(v, seq_lens.tolist(), dim=0) |
| 132 | + |
| 133 | + # Process each sequence in the batch using direct tensor slicing |
| 134 | + for i, (q_batch, k_batch, v_batch) in enumerate( |
| 135 | + zip(q_split, k_split, v_split, strict=False) |
| 136 | + ): |
| 137 | + q_batch = q_batch.transpose(0, 1) # [heads, seq_len, head_dim] |
| 138 | + k_batch = k_batch.permute(1, 2, 0) # [heads, head_dim, seq_len] |
| 139 | + v_batch = v_batch.transpose(0, 1) # [heads, seq_len, head_dim] |
| 140 | + |
| 141 | + # Compute attention scores using batch matrix multiplication |
| 142 | + scores = torch.bmm(q_batch, k_batch) * alpha |
| 143 | + |
| 144 | + # Apply SiLU activation |
| 145 | + scores = (scores / (1.0 + torch.exp(-scores))) * scale |
| 146 | + |
| 147 | + # Apply invalid mask |
| 148 | + if inputs["invalid_attn_mask_type"] == "lower_triangular": |
| 149 | + invalid_mask = torch.tril( |
| 150 | + torch.ones_like(scores, dtype=torch.bool), diagonal=0 |
| 151 | + ) |
| 152 | + scores = torch.where(invalid_mask, scores, torch.zeros_like(scores)) |
| 153 | + |
| 154 | + # Compute and store output |
| 155 | + output_batch = torch.bmm(scores, v_batch) |
| 156 | + output[seq_offsets[i] : seq_offsets[i + 1]] = output_batch.transpose(0, 1) |
| 157 | + |
| 158 | + return output |
| 159 | + |
| 160 | + |
| 161 | +@helion.kernel() |
| 162 | +def _helion_ragged_attention_kernel( |
| 163 | + q: torch.Tensor, |
| 164 | + k: torch.Tensor, |
| 165 | + v: torch.Tensor, |
| 166 | + seq_offsets: torch.Tensor, |
| 167 | + alpha: float, |
| 168 | + invalid_mask_type: str, |
| 169 | + max_seq_len_tensor: torch.Tensor, |
| 170 | +) -> torch.Tensor: |
| 171 | + max_seq_len = max_seq_len_tensor.numel() |
| 172 | + scale = 1.0 / max_seq_len |
| 173 | + |
| 174 | + num_heads = hl.specialize(q.size(1)) |
| 175 | + num_batches = hl.specialize(seq_offsets.size(0) - 1) |
| 176 | + dimV = hl.specialize(v.size(2)) |
| 177 | + |
| 178 | + out = torch.zeros_like(v) |
| 179 | + |
| 180 | + # --- Tile over batch batch, head, sequence --- |
| 181 | + for tile_b, tile_h, tile_q in hl.tile( |
| 182 | + [num_batches, num_heads, max_seq_len], block_size=[1, 1, None] |
| 183 | + ): |
| 184 | + starts = seq_offsets[tile_b.begin] |
| 185 | + ends = seq_offsets[tile_b.begin + 1] |
| 186 | + seq_len = ends - starts |
| 187 | + |
| 188 | + if tile_q.begin < seq_len: |
| 189 | + mask_q = tile_q.index < seq_len |
| 190 | + q_blk = q[tile_q.index + starts, tile_h.begin, :] |
| 191 | + acc = hl.zeros([tile_q, dimV], dtype=torch.float32) |
| 192 | + |
| 193 | + if invalid_mask_type == "lower_triangular": |
| 194 | + low = 0 |
| 195 | + high = tile_q.end |
| 196 | + else: |
| 197 | + low = 0 |
| 198 | + high = seq_len |
| 199 | + |
| 200 | + for tile_kv in hl.tile(low, high, block_size=None): |
| 201 | + mask_kv = tile_kv.index < seq_len |
| 202 | + k_blk = k[tile_kv.index + starts, tile_h.begin, :] |
| 203 | + v_blk = v[tile_kv.index + starts, tile_h.begin, :] |
| 204 | + |
| 205 | + scores = ( |
| 206 | + torch.nn.functional.silu(torch.matmul(q_blk, k_blk.T) * alpha) |
| 207 | + * scale |
| 208 | + ) |
| 209 | + |
| 210 | + if invalid_mask_type == "lower_triangular": |
| 211 | + scores = torch.where( |
| 212 | + (tile_q.index.unsqueeze(1) > tile_kv.index.unsqueeze(0)) |
| 213 | + & mask_q[:, None] |
| 214 | + & mask_kv[None, :], |
| 215 | + scores, |
| 216 | + 0.0, |
| 217 | + ) |
| 218 | + |
| 219 | + acc += torch.matmul(scores.to(v.dtype), v_blk) |
| 220 | + |
| 221 | + # Store result |
| 222 | + out[tile_q.index + starts, tile_h.begin, :] = acc |
| 223 | + |
| 224 | + return out |
| 225 | + |
| 226 | + |
| 227 | +def helion_ragged_attention_function(inputs: dict[str, Any]) -> torch.Tensor: |
| 228 | + """ |
| 229 | + Wrapper function for the Helion ragged attention kernel. |
| 230 | + """ |
| 231 | + |
| 232 | + return _helion_ragged_attention_kernel( |
| 233 | + q=inputs["q"], |
| 234 | + k=inputs["k"], |
| 235 | + v=inputs["v"], |
| 236 | + seq_offsets=inputs["seq_offsets"], |
| 237 | + alpha=inputs["alpha"], |
| 238 | + invalid_mask_type=inputs["invalid_attn_mask_type"], |
| 239 | + max_seq_len_tensor=torch.empty(inputs["N"], device=inputs["q"].device), |
| 240 | + ) |
| 241 | + |
| 242 | + |
| 243 | +def tritonbench_hstu_attention_function(inputs: dict[str, Any]) -> torch.Tensor: |
| 244 | + """ |
| 245 | + Wrapper function for the tritonbench HSTU attention implementation. |
| 246 | +
|
| 247 | + Args: |
| 248 | + inputs: Dictionary containing all the input parameters |
| 249 | +
|
| 250 | + Returns: |
| 251 | + Output tensor from tritonbench HSTU attention |
| 252 | + """ |
| 253 | + if not HAS_HAMMER: |
| 254 | + # Return a dummy tensor with the same shape as expected output |
| 255 | + return torch.zeros_like(inputs["v"]) |
| 256 | + |
| 257 | + return triton_hstu_mha( # pyright: ignore[reportCallIssue,reportPossiblyUnboundVariable] |
| 258 | + N=inputs["N"], |
| 259 | + alpha=inputs["alpha"], |
| 260 | + q=inputs["q"], |
| 261 | + k=inputs["k"], |
| 262 | + v=inputs["v"], |
| 263 | + seq_offsets=inputs["seq_offsets"], |
| 264 | + num_targets=inputs["num_targets"], |
| 265 | + max_attn_len=inputs["max_attn_len"], |
| 266 | + contextual_seq_len=inputs["contextual_seq_len"], |
| 267 | + sort_by_length=inputs["sort_by_length"], |
| 268 | + ) |
| 269 | + |
| 270 | + |
| 271 | +def main() -> None: |
| 272 | + """ |
| 273 | + Main entry point for testing the simplified jagged HSTU attention kernel. |
| 274 | + """ |
| 275 | + inputs = generate_inputs() |
| 276 | + |
| 277 | + baselines = { |
| 278 | + "torch": lambda inputs: reference_jagged_hstu_kernel_pytorch(inputs), |
| 279 | + } |
| 280 | + if HAS_HAMMER: |
| 281 | + baselines["tritonbench"] = lambda inputs: tritonbench_hstu_attention_function( |
| 282 | + inputs |
| 283 | + ) |
| 284 | + |
| 285 | + run_example( |
| 286 | + lambda inputs: helion_ragged_attention_function(inputs), baselines, (inputs,) |
| 287 | + ) |
| 288 | + |
| 289 | + |
| 290 | +if __name__ == "__main__": |
| 291 | + main() |
0 commit comments