Commit de416f9

remove explicit mask def
1 parent c9b6d5c commit de416f9

File tree

4 files changed: +20 −12 lines changed

torchtitan/models/attention.py
torchtitan/models/llama3/__init__.py
torchtitan/models/llama3/model/model.py
torchtitan/protocols/model.py

torchtitan/models/attention.py

Lines changed: 15 additions & 2 deletions
@@ -8,7 +8,7 @@

 import functools
 from collections.abc import Callable
-from typing import ClassVar
+from typing import ClassVar, NamedTuple

 import torch
 import torch.nn.functional as F
@@ -20,13 +20,14 @@
     flex_attention,
 )

-from torch.nn.attention.varlen import varlen_attn, VarlenMetadata
+from torch.nn.attention.varlen import varlen_attn


 __all__ = [
     "FlexAttentionWrapper",
     "ScaledDotProductAttentionWrapper",
     "VarlenAttentionWrapper",
+    "VarlenMetadata",
     "get_causal_mask_mod",
     "get_document_mask_mod",
     "get_sliding_window_mask_mod",
@@ -35,6 +36,18 @@
 ]


+class VarlenMetadata(NamedTuple):
+    """
+    Cumulative sequence positions for queries and keys/values.
+    """
+
+    cu_seq_q: torch.Tensor
+    cu_seq_k: torch.Tensor
+    max_q: int
+    max_k: int
+
+
 class VarlenAttentionWrapper(torch.nn.Module):
     _compiled_varlen_attn: ClassVar[Callable] = torch.compile(
         varlen_attn, mode="max-autotune-no-cudagraphs"
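
Note: the cumulative sequence positions this NamedTuple carries are the standard metadata for packed variable-length attention. The sketch below shows one way such fields could be derived from an EOS-delimited token stream; the helper name and exact layout are illustrative assumptions, not the repo's create_varlen_cu_seqs implementation (which is not shown in this diff).

# Illustrative sketch only: build VarlenMetadata-style fields from a packed
# 1-D token stream whose documents are delimited by eos_id. This mirrors what
# a helper like create_varlen_cu_seqs presumably produces; it is NOT the
# torchtitan implementation.
import torch

from torchtitan.models.attention import VarlenMetadata  # exported by this commit


def build_varlen_metadata_sketch(tokens: torch.Tensor, eos_id: int) -> VarlenMetadata:
    # Document boundaries: position 0, one past every EOS token, and the end.
    eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0]
    boundaries = torch.cat(
        [
            torch.zeros(1, dtype=torch.int64, device=tokens.device),
            eos_positions + 1,
            torch.tensor([tokens.numel()], device=tokens.device),
        ]
    ).unique()
    seq_lens = boundaries[1:] - boundaries[:-1]

    # cu_seq[i] is the offset of sequence i in the packed token dimension,
    # the layout varlen attention kernels expect.
    cu_seq = torch.zeros(seq_lens.numel() + 1, dtype=torch.int32, device=tokens.device)
    cu_seq[1:] = torch.cumsum(seq_lens, dim=0)

    max_len = int(seq_lens.max())
    # Self-attention over the same packed stream: queries and keys/values
    # share offsets and maximum length.
    return VarlenMetadata(cu_seq_q=cu_seq, cu_seq_k=cu_seq, max_q=max_len, max_k=max_len)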

torchtitan/models/llama3/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -46,7 +46,6 @@
         vocab_size=2048,
         rope_theta=500000,
         use_varlen_attn=True,
-        attn_mask_type="varlen_attn",
     ),
     "8B": TransformerModelArgs(
         dim=4096,
@@ -77,7 +76,6 @@
         multiple_of=1024,
         rope_theta=500000,
         use_varlen_attn=True,
-        attn_mask_type="varlen_attn",
     ),
     "70B": TransformerModelArgs(
         dim=8192,
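
The practical effect on these flavor configs: enabling varlen attention is now the use_varlen_attn flag alone, with attn_mask_type left at its default. A minimal sketch of that config-level change (plain dict for illustration, not repo code):

# Sketch only: after this commit the varlen path is enabled by the boolean
# flag; no dedicated mask type string is needed anymore.
varlen_flavor_kwargs = dict(
    vocab_size=2048,
    rope_theta=500000,
    use_varlen_attn=True,
    # attn_mask_type="varlen_attn",  # removed: get_attention_masks no longer
    #                                # needs a dedicated mask type for varlen
)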

torchtitan/models/llama3/model/model.py

Lines changed: 4 additions & 7 deletions
@@ -13,8 +13,6 @@
 from torch import nn
 from torch.nn.attention.flex_attention import and_masks, BlockMask

-from torch.nn.attention.varlen import varlen_attn, VarlenMetadata
-
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
@@ -23,6 +21,7 @@
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
+    VarlenAttentionWrapper,
 )
 from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol
@@ -199,7 +198,7 @@ def __init__(self, model_args: TransformerModelArgs):
         if self.use_flex_attn:
             self.inner_attention = FlexAttentionWrapper()
         elif self.use_varlen_attn:
-            self.inner_attention = varlen_attn
+            self.inner_attention = VarlenAttentionWrapper()
         else:
             self.inner_attention = ScaledDotProductAttentionWrapper()

@@ -251,8 +250,6 @@ def forward(
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            assert isinstance(attention_masks, VarlenMetadata), attention_masks
-
             cu_seq_q = attention_masks.cu_seq_q
             cu_seq_k = attention_masks.cu_seq_k
             max_q = attention_masks.max_q
@@ -495,14 +492,14 @@ def get_attention_masks(
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
+        if self.model_args.use_varlen_attn:
+            return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
                 mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
-            case "varlen_attn":
-                return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"

torchtitan/protocols/model.py

Lines changed: 1 addition & 1 deletion
@@ -12,11 +12,11 @@
 import torch.nn as nn

 from torch.nn.attention.flex_attention import BlockMask
-from torch.nn.attention.varlen import VarlenMetadata

 from torchtitan.components.tokenizer import BaseTokenizer

 from torchtitan.config import JobConfig
+from torchtitan.models.attention import VarlenMetadata


 AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata
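
Since AttentionMasksType now unions in the torchtitan-local VarlenMetadata, downstream code can narrow it the usual way. A small hedged sketch; describe_masks is a hypothetical helper for illustration, not part of the repo:

# Illustrative type narrowing over the AttentionMasksType union defined above.
from torch.nn.attention.flex_attention import BlockMask

from torchtitan.models.attention import VarlenMetadata
from torchtitan.protocols.model import AttentionMasksType


def describe_masks(masks: AttentionMasksType) -> str:
    if isinstance(masks, VarlenMetadata):
        # Packed-sequence metadata consumed by VarlenAttentionWrapper.
        return f"varlen: {masks.cu_seq_q.numel() - 1} sequences, max_q={masks.max_q}"
    if isinstance(masks, BlockMask):
        return "single FlexAttention BlockMask"
    # Remaining case in the union: a mapping of named BlockMasks.
    return f"block masks: {sorted(masks)}"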
