@@ -13,7 +13,7 @@
 from torch import nn
 from torch.nn.attention.flex_attention import and_masks, BlockMask
 
-from torch.nn.attention.varlen import varlen_attn, VarlenMetadata
+from torch.nn.attention.varlen import varlen_attn
 
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
@@ -251,8 +251,6 @@ def forward(
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            assert isinstance(attention_masks, VarlenMetadata), attention_masks
-
             cu_seq_q = attention_masks.cu_seq_q
             cu_seq_k = attention_masks.cu_seq_k
             max_q = attention_masks.max_q
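
The fields unpacked above typically feed a packed varlen attention call a few lines below this hunk. A hedged sketch of that call, with argument order and the is_causal value inferred from the names in this diff rather than confirmed from the file:

# Sketch only: the real call site is below the visible hunk, and
# is_causal=True is an assumption about this model's configuration.
output = varlen_attn(
    xq, xk, xv,          # packed query/key/value tensors
    cu_seq_q, cu_seq_k,  # int32 cumulative per-document offsets
    max_q, max_k,        # longest query/key segment in the batch
    is_causal=True,      # causal attention within each document
)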
@@ -495,14 +493,14 @@ def get_attention_masks(
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
+        if self.model_args.use_varlen_attn:
+            return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
                 mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
-            case "varlen_attn":
-                return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
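
The early return in get_attention_masks hands varlen dispatch a metadata object built from EOS boundaries in the packed batch. As a rough illustration of the cu_seqlens convention only (build_cu_seqlens is a hypothetical helper, not torchtitan's create_varlen_cu_seqs, whose name is all this diff confirms), such offsets can be derived like this:

import torch

def build_cu_seqlens(input_batch: torch.Tensor, eos_id: int) -> torch.Tensor:
    # Flatten the packed (B, S) batch and mark a document boundary
    # immediately after every EOS token.
    flat = input_batch.flatten()
    ends = (flat == eos_id).nonzero(as_tuple=True)[0] + 1
    parts = [torch.zeros(1, dtype=torch.int64, device=flat.device), ends]
    # Close the final document if the batch does not end on EOS.
    if ends.numel() == 0 or ends[-1].item() != flat.numel():
        parts.append(torch.tensor([flat.numel()], device=flat.device))
    # Varlen kernels expect int32 offsets: [0, len_0, len_0 + len_1, ...].
    return torch.cat(parts).to(torch.int32)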