Commit caafc81

attention_type

1 parent de416f9

File tree: 7 files changed (+110 lines, -61 lines)


torchtitan/models/attention.py

Lines changed: 25 additions & 10 deletions
@@ -55,18 +55,32 @@ class VarlenAttentionWrapper(torch.nn.Module):
 
     def forward(
         self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        cu_seq_q: torch.Tensor,
-        cu_seq_k: torch.Tensor,
-        max_q: int,
-        max_k: int,
+        xq: torch.Tensor,
+        xk: torch.Tensor,
+        xv: torch.Tensor,
+        head_dim: torch.Tensor,
+        attention_masks: VarlenMetadata,
         is_causal: bool = True,
     ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
+        cu_seq_q = attention_masks.cu_seq_q
+        cu_seq_k = attention_masks.cu_seq_k
+        max_q = attention_masks.max_q
+        max_k = attention_masks.max_k
+
+        n_local_heads = xq.shape[1]
+        xq_packed = xq.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
+        xk_packed = xk.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
+        xv_packed = xv.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)
 
         return VarlenAttentionWrapper._compiled_varlen_attn(
-            q, k, v, cu_seq_q, cu_seq_k, max_q, max_k, is_causal=True
+            xq_packed,
+            xk_packed,
+            xv_packed,
+            cu_seq_q,
+            cu_seq_k,
+            max_q,
+            max_k,
+            is_causal=True,
         )
 
 
@@ -104,7 +118,6 @@ def forward(
         # `FlexAttentionWrapper._compiled_flex_attn` is correct.
         # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation
         # to convert `lse` to be DTensor.
-
         return FlexAttentionWrapper._compiled_flex_attn(
             q,
             k,
@@ -266,7 +279,9 @@ def create_attention_mask(*args, **kwargs):
     return _compiled_create_block_mask(*args, **kwargs)
 
 
-def create_varlen_cu_seqs(input_batch: torch.Tensor, eos_id: int) -> VarlenMetadata:
+def create_varlen_metadata_for_document(
+    input_batch: torch.Tensor, eos_id: int
+) -> VarlenMetadata:
     """
     Creates cumulative sequence length indices needed for variable length attention
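
For context on the varlen path: the wrapper now receives q/k/v in [bs, n_heads, seq_len, head_dim] layout plus a VarlenMetadata bundle, flattens the tensors into packed [total_tokens, n_heads, head_dim] views, and hands the cumulative sequence lengths to the compiled kernel. Below is a minimal, standalone sketch of that packing step and of how EOS boundaries become cumulative lengths; the VarlenMetadata dataclass and helper names here are stand-ins for illustration, not the torchtitan implementations.

from dataclasses import dataclass

import torch


@dataclass
class VarlenMetadata:  # stand-in mirroring the fields used in the diff above
    cu_seq_q: torch.Tensor
    cu_seq_k: torch.Tensor
    max_q: int
    max_k: int


def cu_seqlens_from_eos(input_batch: torch.Tensor, eos_id: int) -> VarlenMetadata:
    # Treat the batch as one packed token stream; every EOS token closes a document.
    tokens = input_batch.flatten()
    ends = (tokens == eos_id).nonzero(as_tuple=True)[0] + 1
    cu = torch.cat(
        [
            torch.zeros(1, dtype=torch.int32, device=tokens.device),
            ends.to(torch.int32),
            torch.tensor([tokens.numel()], dtype=torch.int32, device=tokens.device),
        ]
    )
    cu = torch.unique_consecutive(cu)  # dedupe if the stream already ends with EOS
    max_len = int((cu[1:] - cu[:-1]).max())
    return VarlenMetadata(cu_seq_q=cu, cu_seq_k=cu, max_q=max_len, max_k=max_len)


def pack_heads(x: torch.Tensor, head_dim: int) -> torch.Tensor:
    # [bs, n_heads, seq_len, head_dim] -> [bs * seq_len, n_heads, head_dim]
    n_local_heads = x.shape[1]
    return x.transpose(1, 2).contiguous().view(-1, n_local_heads, head_dim)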

torchtitan/models/llama3/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -36,7 +36,7 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        use_flex_attn=True,
+        attention_type="flex",
         attn_mask_type="block_causal",
     ),
     "debugmodel_varlen_attn": TransformerModelArgs(
@@ -45,7 +45,7 @@
         n_heads=16,
         vocab_size=2048,
         rope_theta=500000,
-        use_varlen_attn=True,
+        attention_type="varlen",
     ),
     "8B": TransformerModelArgs(
         dim=4096,
@@ -64,7 +64,7 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        use_flex_attn=True,
+        attention_type="flex",
         attn_mask_type="block_causal",
     ),
     "8B_varlen": TransformerModelArgs(
@@ -75,7 +75,7 @@
         ffn_dim_multiplier=1.3,
         multiple_of=1024,
         rope_theta=500000,
-        use_varlen_attn=True,
+        attention_type="varlen",
     ),
     "70B": TransformerModelArgs(
         dim=8192,
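
These hunks show the config migration: the mutually exclusive booleans collapse into a single selector field. A hypothetical before/after sketch, assuming TransformerModelArgs is importable from the llama3 package as laid out in this tree and that unspecified fields keep their defaults:

from torchtitan.models.llama3 import TransformerModelArgs

# Before: use_flex_attn=True / use_varlen_attn=True picked the backend.
# After: a single attention_type literal does.
flex_args = TransformerModelArgs(
    n_heads=16,
    vocab_size=2048,
    rope_theta=500000,
    attention_type="flex",
    attn_mask_type="block_causal",
)
varlen_args = TransformerModelArgs(
    n_heads=16,
    vocab_size=2048,
    rope_theta=500000,
    attention_type="varlen",
)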

torchtitan/models/llama3/infra/parallelize.py

Lines changed: 3 additions & 1 deletion
@@ -67,7 +67,9 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    # use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    attn_type = getattr(model.model_args, "attention_type", False)
+    use_flex_attn = attn_type == "flex"
     if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
         raise NotImplementedError("CP support for FlexAttention is still in progress.")
 

torchtitan/models/llama3/model/args.py

Lines changed: 8 additions & 4 deletions
@@ -9,8 +9,9 @@
 
 from dataclasses import dataclass, field
 
-from torch import nn
+from typing import Literal
 
+from torch import nn
 from torchtitan.config import JobConfig
 from torchtitan.models.utils import get_dense_model_nparams_and_flops
 from torchtitan.protocols.model import BaseModelArgs
@@ -43,8 +44,8 @@ class TransformerModelArgs(BaseModelArgs):
     # `False`, each uses the total number of transformer blocks
     depth_init: bool = True
 
-    use_flex_attn: bool = False
-    use_varlen_attn: bool = False
+    attention_type: Literal["flex", "varlen"] = None
+    # use_flex_attn: bool = True
     attn_mask_type: str = "causal"
     eos_id: int = 0
 
@@ -56,7 +57,10 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
             )
         self.max_seq_len = seq_len
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
+        if (
+            job_config.parallelism.context_parallel_degree > 1
+            and self.attention_type == "flex"
+        ):
             raise NotImplementedError(
                 "CP support for FlexAttention is still in progress."
             )

torchtitan/models/llama3/model/model.py

Lines changed: 45 additions & 39 deletions
@@ -16,12 +16,13 @@
 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
-    create_varlen_cu_seqs,
+    create_varlen_metadata_for_document,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
     VarlenAttentionWrapper,
+    VarlenMetadata,
 )
 from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol
@@ -193,8 +194,8 @@ def __init__(self, model_args: TransformerModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-        self.use_flex_attn = model_args.use_flex_attn
-        self.use_varlen_attn = model_args.use_varlen_attn
+        self.use_flex_attn = model_args.attention_type == "flex"
+        self.use_varlen_attn = model_args.attention_type == "varlen"
         if self.use_flex_attn:
             self.inner_attention = FlexAttentionWrapper()
         elif self.use_varlen_attn:
@@ -212,7 +213,6 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
-        **kwargs,
     ):
         """
         Forward pass of the attention module.
@@ -250,30 +250,13 @@ def forward(
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            cu_seq_q = attention_masks.cu_seq_q
-            cu_seq_k = attention_masks.cu_seq_k
-            max_q = attention_masks.max_q
-            max_k = attention_masks.max_k
-
-            n_local_heads = xq.shape[1]
-            xq_packed = (
-                xq.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-            xk_packed = (
-                xk.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-            xv_packed = (
-                xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-
+            assert isinstance(attention_masks, VarlenMetadata), attention_masks
             output = self.inner_attention(
-                xq_packed,
-                xk_packed,
-                xv_packed,
-                cu_seq_q,
-                cu_seq_k,
-                max_q,
-                max_k,
+                xq,
+                xk,
+                xv,
+                self.head_dim,
+                attention_masks,
                 is_causal=True,
             )
         else:
@@ -375,7 +358,6 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
-        **kwargs,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -388,9 +370,7 @@ def forward(
             torch.Tensor: Output tensor after applying attention and feedforward layers.
 
         """
-        h = x + self.attention(
-            self.attention_norm(x), freqs_cis, attention_masks, **kwargs
-        )
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
 
@@ -485,34 +465,61 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_scaling_args,
         )
 
-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
-        tokenizer: BaseTokenizer,
+        eos_id: int,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
-        if self.model_args.use_varlen_attn:
-            return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
+
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                mask_mods.append(get_document_mask_mod(input_batch, eos_id))
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                 )
+
         return create_attention_mask(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )
 
+    def _get_varlen_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        return create_varlen_metadata_for_document(input_batch, eos_id)
+
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attention_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer.eos_id, extra_inputs
+                )
+            case "varlen":
+                return self._get_varlen_attention_masks(
+                    input_batch, tokenizer.eos_id, extra_inputs
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        **kwargs,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -531,8 +538,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
 
         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks, **kwargs)
-
+            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
         return output
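
The flex branch above assembles its BlockMask from a causal mask_mod, plus a document mask_mod when attn_mask_type is "block_causal". Here is a self-contained sketch of the same idea built directly with PyTorch's flex-attention utilities; torchtitan's get_causal_mask_mod, get_document_mask_mod, and create_attention_mask wrappers are replaced by inline equivalents for illustration only.

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

eos_id = 0
# One packed row holding documents [5 7 0], [9 3 0], and a trailing [2 4].
input_batch = torch.tensor([[5, 7, 0, 9, 3, 0, 2, 4]])

# Assign a document id to every position; each EOS stays in the document it closes.
is_eos = (input_batch[0] == eos_id).to(torch.int64)
doc_id = is_eos.cumsum(dim=-1) - is_eos


def causal(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def document(b, h, q_idx, kv_idx):
    return doc_id[q_idx] == doc_id[kv_idx]


B, S = input_batch.shape  # "block_causal" uses the real batch size; "causal" uses B=1
block_mask = create_block_mask(and_masks(causal, document), B, None, S, S, device="cpu")
print(block_mask)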

torchtitan/protocols/model.py

Lines changed: 18 additions & 0 deletions
@@ -71,3 +71,21 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def _get_varlen_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        raise NotImplementedError(
+            "This model does not support variable length attention."
+        )
+
+    def _get_flex_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        raise NotImplementedError("This model does not support flex attention.")

torchtitan/train.py

Lines changed: 7 additions & 3 deletions
@@ -454,9 +454,13 @@ def post_dataloading_process(
         # extra_kwargs are.
         extra_kwargs: dict[str, Any] = {}
 
-        if getattr(self.model_args, "use_flex_attn", False) or getattr(
-            self.model_args, "use_varlen_attn", False
-        ):
+        attn_type = getattr(self.model_args, "attention_type", False)
+        use_varlen_attn = attn_type == "varlen"
+        use_flex_attn = (
+            getattr(self.model_args, "use_flex_attn", False) or attn_type == "flex"
+        )
+
+        if use_flex_attn or use_varlen_attn:
             extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks(
                 input_batch=inputs,
                 tokenizer=self.tokenizer,
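
This guard keeps reading the legacy use_flex_attn flag while preferring the new attention_type selector, so model args that have not migrated still take the mask-building path. A small standalone sketch of the same logic against a stand-in model_args object (SimpleNamespace here is illustrative, not what train.py uses):

from types import SimpleNamespace

# Stand-in for self.model_args; older model args may only expose use_flex_attn.
model_args = SimpleNamespace(attention_type="varlen")

attn_type = getattr(model_args, "attention_type", False)
use_varlen_attn = attn_type == "varlen"
use_flex_attn = getattr(model_args, "use_flex_attn", False) or attn_type == "flex"

if use_flex_attn or use_varlen_attn:
    print("get_attention_masks() will be called for this batch")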
