
Commit a902cbe

remove explicit mask def
1 parent c9b6d5c commit a902cbe

File tree

2 files changed: 3 additions & 7 deletions

torchtitan/models/llama3/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -46,7 +46,6 @@
         vocab_size=2048,
         rope_theta=500000,
         use_varlen_attn=True,
-        attn_mask_type="varlen_attn",
     ),
     "8B": TransformerModelArgs(
         dim=4096,
@@ -77,7 +76,6 @@
         multiple_of=1024,
         rope_theta=500000,
         use_varlen_attn=True,
-        attn_mask_type="varlen_attn",
     ),
     "70B": TransformerModelArgs(
         dim=8192,
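
After this commit, setting use_varlen_attn=True is all a flavor needs; the redundant attn_mask_type="varlen_attn" entry is gone. A minimal sketch of the resulting shape, using a hypothetical trimmed stand-in for TransformerModelArgs (defaults beyond the fields visible in this diff are assumptions):

    # Hypothetical stand-in for torchtitan's TransformerModelArgs, trimmed to
    # the fields this diff touches; defaults here are assumptions.
    from dataclasses import dataclass

    @dataclass
    class ModelArgsSketch:
        vocab_size: int = 2048
        rope_theta: float = 500000.0
        use_varlen_attn: bool = False  # now gates the varlen path by itself

    # Before this commit a flavor also had to pass attn_mask_type="varlen_attn";
    # after it, the flag alone is enough.
    flavor = ModelArgsSketch(use_varlen_attn=True)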

torchtitan/models/llama3/model/model.py

Lines changed: 3 additions & 5 deletions
@@ -13,7 +13,7 @@
 from torch import nn
 from torch.nn.attention.flex_attention import and_masks, BlockMask
 
-from torch.nn.attention.varlen import varlen_attn, VarlenMetadata
+from torch.nn.attention.varlen import varlen_attn
 
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
@@ -251,8 +251,6 @@ def forward(
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            assert isinstance(attention_masks, VarlenMetadata), attention_masks
-
             cu_seq_q = attention_masks.cu_seq_q
             cu_seq_k = attention_masks.cu_seq_k
             max_q = attention_masks.max_q
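
With the isinstance check removed, forward() relies only on the metadata object exposing the three attributes it reads. A minimal sketch of that contract, using a hypothetical stand-in for the VarlenMetadata class the deleted import referenced (only the fields read above are modeled):

    import torch
    from dataclasses import dataclass

    # Hypothetical stand-in for VarlenMetadata; only the attributes that
    # forward() reads after this commit are modeled here.
    @dataclass
    class VarlenMetaSketch:
        cu_seq_q: torch.Tensor  # cumulative query lengths, shape (num_seqs + 1,)
        cu_seq_k: torch.Tensor  # cumulative key lengths, same shape
        max_q: int              # length of the longest packed sequence

    # Two packed sequences of lengths 3 and 5 -> boundaries at 0, 3, 8.
    meta = VarlenMetaSketch(
        cu_seq_q=torch.tensor([0, 3, 8], dtype=torch.int32),
        cu_seq_k=torch.tensor([0, 3, 8], dtype=torch.int32),
        max_q=5,
    )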
@@ -495,14 +493,14 @@ def get_attention_masks(
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
+        if self.model_args.use_varlen_attn:
+            return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
                 mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
-            case "varlen_attn":
-                return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
