 from torchtitan.components.tokenizer import BaseTokenizer
 from torchtitan.models.attention import (
     create_attention_mask,
-    create_varlen_cu_seqs,
+    create_varlen_metadata_for_document,
     FlexAttentionWrapper,
     get_causal_mask_mod,
     get_document_mask_mod,
     ScaledDotProductAttentionWrapper,
     VarlenAttentionWrapper,
+    VarlenMetadata,
 )
 from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.protocols.train_spec import ModelProtocol
@@ -193,8 +194,8 @@ def __init__(self, model_args: TransformerModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )

-        self.use_flex_attn = model_args.use_flex_attn
-        self.use_varlen_attn = model_args.use_varlen_attn
+        self.use_flex_attn = model_args.attention_type == "flex"
+        self.use_varlen_attn = model_args.attention_type == "varlen"
         if self.use_flex_attn:
             self.inner_attention = FlexAttentionWrapper()
         elif self.use_varlen_attn:
@@ -212,7 +213,6 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
-        **kwargs,
     ):
         """
         Forward pass of the attention module.
@@ -250,30 +250,13 @@ def forward(
             assert isinstance(attention_masks, BlockMask), attention_masks
             output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
         elif self.use_varlen_attn:
-            cu_seq_q = attention_masks.cu_seq_q
-            cu_seq_k = attention_masks.cu_seq_k
-            max_q = attention_masks.max_q
-            max_k = attention_masks.max_k
-
-            n_local_heads = xq.shape[1]
-            xq_packed = (
-                xq.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-            xk_packed = (
-                xk.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-            xv_packed = (
-                xv.transpose(1, 2).contiguous().view(-1, n_local_heads, self.head_dim)
-            )
-
+            assert isinstance(attention_masks, VarlenMetadata), attention_masks
             output = self.inner_attention(
-                xq_packed,
-                xk_packed,
-                xv_packed,
-                cu_seq_q,
-                cu_seq_k,
-                max_q,
-                max_k,
+                xq,
+                xk,
+                xv,
+                self.head_dim,
+                attention_masks,
                 is_causal=True,
             )
         else:
@@ -375,7 +358,6 @@ def forward(
         x: torch.Tensor,
         freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None,
-        **kwargs,
     ):
         """
         Perform a forward pass through the TransformerBlock.
@@ -388,9 +370,7 @@ def forward(
            torch.Tensor: Output tensor after applying attention and feedforward layers.

        """
-        h = x + self.attention(
-            self.attention_norm(x), freqs_cis, attention_masks, **kwargs
-        )
+        h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out

@@ -485,34 +465,61 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
             self.model_args.rope_scaling_args,
         )

-    def get_attention_masks(
+    def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
-        tokenizer: BaseTokenizer,
+        eos_id: int,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
-        if self.model_args.use_varlen_attn:
-            return create_varlen_cu_seqs(input_batch, tokenizer.eos_id)
+
         match self.model_args.attn_mask_type:
             case "causal":
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
+                mask_mods.append(get_document_mask_mod(input_batch, eos_id))
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
                 )
+
         return create_attention_mask(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def _get_varlen_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        eos_id: int,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        return create_varlen_metadata_for_document(input_batch, eos_id)
+
+    def get_attention_masks(
+        self,
+        input_batch: torch.Tensor,
+        tokenizer: BaseTokenizer,
+        extra_inputs: dict[str, torch.Tensor] | None = None,
+    ) -> AttentionMasksType:
+        match self.model_args.attention_type:
+            case "flex":
+                return self._get_flex_attention_masks(
+                    input_batch, tokenizer.eos_id, extra_inputs
+                )
+            case "varlen":
+                return self._get_varlen_attention_masks(
+                    input_batch, tokenizer.eos_id, extra_inputs
+                )
+            case _:
+                raise NotImplementedError(
+                    "Only varlen and flex attn masks are supported"
+                )
+
     def forward(
         self,
         tokens: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
-        **kwargs,
     ):
         """
         Perform a forward pass through the Transformer model.
@@ -531,8 +538,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks, **kwargs)
-
+            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
         return output
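
For reference, a minimal usage sketch of the mask API as reworked in this diff. It only exercises names visible above (`get_attention_masks`, `attention_type`, the forward signature); the import path, `model_args`, `tokenizer`, and `tokens` are assumed from the usual torchtitan setup and are not asserted by this change.

```python
import torch

# Import path assumed; adjust to wherever Transformer/TransformerModelArgs live.
from torchtitan.models.llama3.model import Transformer, TransformerModelArgs

# model_args.attention_type is expected to be "flex" or "varlen" (see __init__ diff).
model_args = TransformerModelArgs(attention_type="varlen")
model = Transformer(model_args)

tokens = torch.randint(0, model_args.vocab_size, (2, 1024))

# Dispatch in get_attention_masks (see diff):
#   "flex"   -> BlockMask built from the causal / block_causal mask mods
#   "varlen" -> VarlenMetadata from create_varlen_metadata_for_document
masks = model.get_attention_masks(input_batch=tokens, tokenizer=tokenizer)  # tokenizer assumed in scope

# The same masks object is threaded through every layer's attention module;
# **kwargs is no longer part of the forward signatures.
logits = model(tokens, attention_masks=masks)
```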