@@ -194,14 +194,14 @@ def __init__(self, model_args: TransformerModelArgs):
             model_args.n_heads * self.head_dim, model_args.dim, bias=False
         )
 
-        self.use_flex_attn = model_args.attention_type == "flex"
-        self.use_varlen_attn = model_args.attention_type == "varlen"
-        if self.use_flex_attn:
-            self.inner_attention = FlexAttentionWrapper()
-        elif self.use_varlen_attn:
-            self.inner_attention = VarlenAttentionWrapper()
-        else:
-            self.inner_attention = ScaledDotProductAttentionWrapper()
+        self.attn_type = model_args.attention_type
+        match self.attn_type:
+            case "flex":
+                self.inner_attention = FlexAttentionWrapper()
+            case "varlen":
+                self.inner_attention = VarlenAttentionWrapper()
+            case _:
+                self.inner_attention = ScaledDotProductAttentionWrapper()
 
     def init_weights(self, init_std: float):
         for linear in (self.wq, self.wk, self.wv):
@@ -246,22 +246,23 @@ def forward(
         xk = keys.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         xv = values.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
 
-        if self.use_flex_attn:
-            assert isinstance(attention_masks, BlockMask), attention_masks
-            output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
-        elif self.use_varlen_attn:
-            assert isinstance(attention_masks, VarlenMetadata), attention_masks
-            output = self.inner_attention(
-                xq,
-                xk,
-                xv,
-                self.head_dim,
-                attention_masks,
-                is_causal=True,
-            )
-        else:
-            assert attention_masks is None
-            output = self.inner_attention(xq, xk, xv)
+        match self.attn_type:
+            case "flex":
+                assert isinstance(attention_masks, BlockMask), attention_masks
+                output = self.inner_attention(xq, xk, xv, block_mask=attention_masks)
+            case "varlen":
+                assert isinstance(attention_masks, VarlenMetadata), attention_masks
+                output = self.inner_attention(
+                    xq,
+                    xk,
+                    xv,
+                    self.head_dim,
+                    attention_masks,
+                    is_causal=True,
+                )
+            case _:
+                assert attention_masks is None
+                output = self.inner_attention(xq, xk, xv)
 
         output = output.transpose(
             1, 2
@@ -468,7 +469,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
     def _get_flex_attention_masks(
         self,
         input_batch: torch.Tensor,
-        eos_id: int,
+        tokenizer: BaseTokenizer,
         extra_inputs: dict[str, torch.Tensor] | None = None,
     ) -> AttentionMasksType:
         mask_mods = [get_causal_mask_mod()]
@@ -478,7 +479,7 @@ def _get_flex_attention_masks(
                 B = 1
             case "block_causal":
                 B = input_batch.shape[0]
-                mask_mods.append(get_document_mask_mod(input_batch, eos_id))
+                mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id))
             case _:
                 raise ValueError(
                     f"Unknown attention mask type: {self.model_args.attn_mask_type}"
@@ -505,7 +506,7 @@ def get_attention_masks(
         match self.model_args.attention_type:
             case "flex":
                 return self._get_flex_attention_masks(
-                    input_batch, tokenizer.eos_id, extra_inputs
+                    input_batch, tokenizer, extra_inputs
                 )
             case "varlen":
                 return self._get_varlen_attention_masks(