|
8 | 8 |
|
9 | 9 | import functools |
10 | 10 | from collections.abc import Callable |
11 | | -from typing import ClassVar |
| 11 | +from typing import ClassVar, NamedTuple |
12 | 12 |
|
13 | 13 | import torch |
14 | 14 | import torch.nn.functional as F |
|
20 | 20 | flex_attention, |
21 | 21 | ) |
22 | 22 |
|
| 23 | +from torch.nn.attention.varlen import varlen_attn |
| 24 | + |
23 | 25 |
|
24 | 26 | __all__ = [ |
25 | 27 | "FlexAttentionWrapper", |
26 | 28 | "ScaledDotProductAttentionWrapper", |
| 29 | + "VarlenAttentionWrapper", |
| 30 | + "VarlenMetadata", |
27 | 31 | "get_causal_mask_mod", |
28 | 32 | "get_document_mask_mod", |
29 | 33 | "get_sliding_window_mask_mod", |
|
32 | 36 | ] |
33 | 37 |
|
34 | 38 |
|
| 39 | +class VarlenMetadata(NamedTuple): |
| 40 | + """ |
| 41 | + Cumulative sequence positions (cu_seq_q, cu_seq_k) and maximum sequence lengths |
| 42 | + (max_q, max_k) for queries and keys/values, as consumed by varlen_attn. |
| 43 | + """ |
| 44 | + |
| 45 | + cu_seq_q: torch.Tensor |
| 46 | + cu_seq_k: torch.Tensor |
| 47 | + max_q: int |
| 48 | + max_k: int |
| 49 | + |
| 50 | + |
| 51 | +class VarlenAttentionWrapper(torch.nn.Module): |
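| | + # Compile varlen_attn once at class level so all instances share the compiled kernel, |
| | + # mirroring FlexAttentionWrapper._compiled_flex_attn below. |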
| 52 | + _compiled_varlen_attn: ClassVar[Callable] = torch.compile( |
| 53 | + varlen_attn, mode="max-autotune-no-cudagraphs" |
| 54 | + ) |
| 55 | + |
| 56 | + def forward( |
| 57 | + self, |
| 58 | + xq: torch.Tensor, |
| 59 | + xk: torch.Tensor, |
| 60 | + xv: torch.Tensor, |
| 61 | + head_dim: int, |
| 62 | + attention_masks: VarlenMetadata, |
| 63 | + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: |
| 64 | + cu_seq_q = attention_masks.cu_seq_q |
| 65 | + cu_seq_k = attention_masks.cu_seq_k |
| 66 | + max_q = attention_masks.max_q |
| 67 | + max_k = attention_masks.max_k |
| 68 | + |
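| | + # Pack (bs, n_heads, seq_len, head_dim) into (total_tokens, n_heads, head_dim), |
| | + # the flattened, document-packed layout varlen_attn expects. |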
| 69 | + n_local_heads = xq.shape[1] |
| 70 | + xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) |
| 71 | + xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) |
| 72 | + xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) |
| 73 | + |
| 74 | + return VarlenAttentionWrapper._compiled_varlen_attn( |
| 75 | + xq_packed, |
| 76 | + xk_packed, |
| 77 | + xv_packed, |
| 78 | + cu_seq_q, |
| 79 | + cu_seq_k, |
| 80 | + max_q, |
| 81 | + max_k, |
| 82 | + is_causal=True, |
| 83 | + ) |
| 84 | + |
| 85 | + |
35 | 86 | class FlexAttentionWrapper(torch.nn.Module): |
36 | 87 | """Wrapper around `flex_attention` to make it torch.compile and CP compatible. |
37 | 88 |
|
@@ -66,7 +117,6 @@ def forward( |
66 | 117 | # `FlexAttentionWrapper._compiled_flex_attn` is correct. |
67 | 118 | # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation |
68 | 119 | # to convert `lse` to be DTensor. |
69 | | - |
70 | 120 | return FlexAttentionWrapper._compiled_flex_attn( |
71 | 121 | q, |
72 | 122 | k, |
@@ -226,3 +276,60 @@ def create_attention_mask(*args, **kwargs): |
226 | 276 | arguments. |
227 | 277 | """ |
228 | 278 | return _compiled_create_block_mask(*args, **kwargs) |
| 279 | + |
| 280 | + |
| 281 | +def create_varlen_metadata_for_document( |
| 282 | + input_batch: torch.Tensor, eos_id: int |
| 283 | +) -> VarlenMetadata: |
| 284 | + """ |
| 285 | + Creates the cumulative sequence length indices needed for variable length attention. |
| 286 | + |
| 287 | + Args: |
| 288 | + input_batch: token ids of shape (batch_size, seq_len), with documents packed along the sequence dimension |
| 289 | + eos_id: token id that marks the end of each document |
| 290 | + |
| 291 | + Returns: |
| 292 | + VarlenMetadata with cumulative sequence indices for queries/keys and the maximum sequence lengths |
| 293 | + """ |
| 294 | + batch_size, seq_len = input_batch.shape |
| 295 | + device = input_batch.device |
| 296 | + cu_seqlens_list, all_seq_lengths = [], [] |
| 297 | + offset = 0 |
| 298 | + max_seqlen = 0 |
| 299 | + |
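| | + # Each row packs several documents separated by eos_id; walk the batch and record |
| | + # every document boundary as a cumulative offset into the flattened batch. |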
| 300 | + for b in range(batch_size): |
| 301 | + tokens = input_batch[b] |
| 302 | + eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) |
| 303 | + sample_cu_seqlens = torch.cat( |
| 304 | + [ |
| 305 | + torch.tensor([0], dtype=torch.int32, device=device), |
| 306 | + eos_positions + 1, |
| 307 | + torch.tensor([seq_len], dtype=torch.int32, device=device), |
| 308 | + ] |
| 309 | + ) |
| 310 | + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) |
| 311 | + |
| 312 | + seq_lengths = torch.diff(sample_cu_seqlens) |
| 313 | + all_seq_lengths.append(seq_lengths) |
| 314 | + |
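| | + # Shift this sample's boundaries by the running offset so the indices are global |
| | + # across the flattened (batch_size * seq_len) token layout. |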
| 315 | + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset |
| 316 | + cu_seqlens_list.append(cu_seqlens_adjusted) |
| 317 | + |
| 318 | + offset += seq_len |
| 319 | + |
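| | + # Close the last segment with the total packed length (batch_size * seq_len). |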
| 320 | + packed_cu_seqlens = torch.cat( |
| 321 | + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)] |
| 322 | + ) |
| 323 | + |
| 324 | + max_seqlen = 0 |
| 325 | + if len(all_seq_lengths) > 0: |
| 326 | + all_seq_lengths = torch.cat(all_seq_lengths) |
| 327 | + # device to host sync but only done once per model forward |
| 328 | + max_seqlen = all_seq_lengths.max().item() |
| 329 | + |
| 330 | + return VarlenMetadata( |
| 331 | + cu_seq_q=packed_cu_seqlens, |
| 332 | + cu_seq_k=packed_cu_seqlens, |
| 333 | + max_q=max_seqlen, |
| 334 | + max_k=max_seqlen, |
| 335 | + ) |
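For context, a minimal usage sketch (not part of this commit; bs, seq_len, eos_id, and the projection producing xq/xk/xv are illustrative, and the (bs, n_heads, seq_len, head_dim) layout is assumed from the wrapper's transpose):

    # Hypothetical wiring inside a model forward pass.
    attn = VarlenAttentionWrapper()
    # Build the cu_seqlens metadata once per batch from the packed token ids.
    metadata = create_varlen_metadata_for_document(input_batch, eos_id=eos_id)
    # xq, xk, xv: (bs, n_heads, seq_len, head_dim) after QKV projection and RoPE.
    out = attn(xq, xk, xv, head_dim, metadata)
    # The output comes back packed as (bs * seq_len, n_heads, head_dim);
    # restore the batch view before the output projection.
    out = out.reshape(bs, seq_len, -1)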