From 5e7df8359df3a6ddcd064719470890c1c52f853a Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 12 May 2025 13:47:08 +0200 Subject: [PATCH 01/68] starting attn refactor for encoder decoder models via bart (eager + sdpa) --- .../integrations/sdpa_attention.py | 11 +- src/transformers/models/bart/modeling_bart.py | 140 +++++++++--------- 2 files changed, 78 insertions(+), 73 deletions(-) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 9c924c048ad5..515eef5ae988 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -30,9 +30,8 @@ def sdpa_attention_forward( key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) - causal_mask = attention_mask - if attention_mask is not None and causal_mask.ndim == 4: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + if attention_mask is not None and attention_mask.ndim == 4: + attention_mask = attention_mask[:, :, :, : key.shape[-2]] # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -44,7 +43,9 @@ def sdpa_attention_forward( # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool` if is_causal is None: - is_causal = query.shape[2] > 1 and causal_mask is None + # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag + # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns + is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools. @@ -55,7 +56,7 @@ def sdpa_attention_forward( query, key, value, - attn_mask=causal_mask, + attn_mask=attention_mask, dropout_p=dropout, scale=scaling, is_causal=is_causal, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a9f9c31ac0f4..902d330f0860 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -17,7 +17,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -42,7 +42,7 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...utils import auto_docstring, logging from .configuration_bart import BartConfig @@ -105,6 +105,33 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class BartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -150,6 +177,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -157,10 +185,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -175,18 +207,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -198,69 +230,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + # TODO: flex attn + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class BartFlashAttention2(BartAttention): @@ -497,7 +501,7 @@ def forward( BART_ATTENTION_CLASSES = { "eager": BartAttention, - "sdpa": BartSdpaAttention, + "sdpa": BartAttention, "flash_attention_2": BartFlashAttention2, } From faf7914112933605f17b15637099df0a7fb1a12b Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 12 May 2025 14:38:09 +0200 Subject: [PATCH 02/68] flash attention works, remove unnecessary code --- src/transformers/models/bart/modeling_bart.py | 234 +----------------- 1 file changed, 1 insertion(+), 233 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 902d330f0860..eeaf0bbfa969 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -267,242 +267,10 @@ def forward( return attn_output, attn_weights, past_key_value -class BartFlashAttention2(BartAttention): - """ - Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # BartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("BartFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype - else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( - query_states, - key_states, - value_states, - attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, - ) - - attn_output = attn_output.reshape(bsz, q_len, -1) - attn_output = self.out_proj(attn_output) - - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value - - -class BartSdpaAttention(BartAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - BART_ATTENTION_CLASSES = { "eager": BartAttention, "sdpa": BartAttention, - "flash_attention_2": BartFlashAttention2, + "flash_attention_2": BartAttention, } From 90a90d3fddfaf73aa8ba63234031d98586c12888 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 12 May 2025 17:53:25 +0200 Subject: [PATCH 03/68] flex attention support for bart!, gotta check if the renaming is not too aggressive --- src/transformers/integrations/__init__.py | 2 +- .../integrations/flex_attention.py | 44 ++++++---- src/transformers/models/bart/modeling_bart.py | 82 ++++++++++++------- 3 files changed, 83 insertions(+), 45 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index 8d03c5cf790e..d4bc3820314b 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -282,7 +282,7 @@ except OptionalDependencyNotAvailable: pass else: - from .flex_attention import make_flex_block_causal_mask + from .flex_attention import make_flex_block_mask else: import sys diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 56c35e8d1956..e0d938661e62 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -36,10 +36,7 @@ if is_torch_flex_attn_available(): - from torch.nn.attention.flex_attention import BlockMask, flex_attention - from torch.nn.attention.flex_attention import ( - create_block_mask as create_block_causal_mask_flex, - ) + from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention class WrappedFlexAttention: @@ -82,17 +79,18 @@ def __call__(self): Offset = Union[torch.Tensor, int] -def make_flex_block_causal_mask( +def make_flex_block_mask( attention_mask_2d: torch.Tensor, attention_chunk_size: Optional[int] = None, query_length=None, key_length=None, offsets: Optional[Tuple[Offset, Offset]] = None, + is_causal: Optional[bool] = True, ) -> "BlockMask": """ - Create a block causal document mask for a batch of sequences, both packed and unpacked. - Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. - The resultant BlockMask is a compressed representation of the full block causal + Create a block (causal) document mask for a batch of sequences, both packed and unpacked. + Create Block (causal) logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`. + The resultant BlockMask is a compressed representation of the full block (causal) mask. BlockMask is essential for performant computation of flex attention. See: https://pytorch.org/blog/flexattention/ @@ -151,7 +149,20 @@ def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx): causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx) return chunk_mask & causal_doc_mask - mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod + def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx): + """ + Utilizes default attention mask to enable encoder and encoder-decoder + attention masks. + """ + document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] + padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0 + final_mask = padding_mask & document_mask + return final_mask + + if not is_causal: + mask_mod_maybe_combined = default_mask_mod + else: + mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod if offsets is not None: q_offset = offsets[0] @@ -163,7 +174,8 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx): return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv) else: mask_mod = mask_mod_maybe_combined - return create_block_causal_mask_flex( + + return create_block_mask( mask_mod=mask_mod, B=batch_size, H=None, # attention head @@ -216,20 +228,20 @@ def flex_attention_forward( **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: block_mask = None - causal_mask = None + score_mask = None if isinstance(attention_mask, BlockMask): block_mask = attention_mask else: - causal_mask = attention_mask + score_mask = attention_mask - if causal_mask is not None: - causal_mask = causal_mask[:, :, :, : key.shape[-2]] + if score_mask is not None: + score_mask = score_mask[:, :, :, : key.shape[-2]] def score_mod(score, batch_idx, head_idx, q_idx, kv_idx): if softcap is not None: score = softcap * torch.tanh(score / softcap) - if causal_mask is not None: - score = score + causal_mask[batch_idx][0][q_idx][kv_idx] + if score_mask is not None: + score = score + score_mask[batch_idx][0][q_idx][kv_idx] if head_mask is not None: score = score + head_mask[batch_idx][head_idx][0][0] return score diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index eeaf0bbfa969..2f5acc88acf5 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -32,7 +32,6 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -43,12 +42,12 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_bart import BartConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_mask logger = logging.get_logger(__name__) @@ -166,9 +165,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -230,13 +226,13 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - # TODO: flex attn attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation if self.config._attn_implementation != "eager": if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ "sdpa", "flash_attention_2", + "flex_attention", ]: logger.warning_once( f"Falling back to eager attention because `{attention_type}` does not support" @@ -267,19 +263,12 @@ def forward( return attn_output, attn_weights, past_key_value -BART_ATTENTION_CLASSES = { - "eager": BartAttention, - "sdpa": BartAttention, - "flash_attention_2": BartAttention, -} - - class BartEncoderLayer(nn.Module): def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -349,7 +338,7 @@ def __init__(self, config: BartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BartAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -362,7 +351,7 @@ def __init__(self, config: BartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BartAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -498,6 +487,7 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -570,8 +560,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No embed_dim, ) self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -658,14 +646,22 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # expand attention_mask + _unsupported_features = output_attentions is True or head_mask is not None if attention_mask is not None: - if self._use_flash_attention_2: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_mask(attention_mask, is_causal=False) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) @@ -754,8 +750,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.d_model, ) self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -875,10 +869,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -887,6 +882,22 @@ def forward( inputs_embeds, past_key_values_length, ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) else: # 4d mask is passed through the layers attention_mask = _prepare_4d_causal_attention_mask( @@ -895,9 +906,9 @@ def forward( # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -906,6 +917,21 @@ def forward( inputs_embeds.dtype, tgt_len=input_shape[-1], ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + + #encoder_attention_mask = _prepare_4d_attention_mask( + # encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + #) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( From 259258d4c390cf46172d697ee6b3792b86f86626 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 12 May 2025 18:19:44 +0200 Subject: [PATCH 04/68] some comments --- src/transformers/integrations/flex_attention.py | 1 + src/transformers/models/bart/modeling_bart.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index e0d938661e62..316e20365f79 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -155,6 +155,7 @@ def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx): attention masks. """ document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx] + # kv indexing is crucial in order to work correctly padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0 final_mask = padding_mask & document_mask return final_mask diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 2f5acc88acf5..7523000ff524 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -32,6 +32,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -42,7 +43,8 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging from .configuration_bart import BartConfig @@ -173,7 +175,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, - **kwargs, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -928,10 +932,6 @@ def forward( query_length=input_shape[-1], is_causal=False, ) - - #encoder_attention_mask = _prepare_4d_attention_mask( - # encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - #) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( From afb293459263600144d12415d48caf6bb2760079 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 12 May 2025 18:45:42 +0200 Subject: [PATCH 05/68] skip flex grad test for standalone as done with the other test --- tests/models/bart/test_modeling_bart.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index d23de5faa306..6399cdf8725d 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1534,3 +1534,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return From 25db34af1851179463ddb2b3c81e3c35b7ff899e Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 13 May 2025 10:21:44 +0200 Subject: [PATCH 06/68] revert flex attn rename (for now), sdpa simplify, and todos --- src/transformers/integrations/__init__.py | 2 +- src/transformers/integrations/flex_attention.py | 3 ++- src/transformers/integrations/sdpa_attention.py | 2 +- src/transformers/models/bart/modeling_bart.py | 10 +++++----- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/transformers/integrations/__init__.py b/src/transformers/integrations/__init__.py index d4bc3820314b..8d03c5cf790e 100755 --- a/src/transformers/integrations/__init__.py +++ b/src/transformers/integrations/__init__.py @@ -282,7 +282,7 @@ except OptionalDependencyNotAvailable: pass else: - from .flex_attention import make_flex_block_mask + from .flex_attention import make_flex_block_causal_mask else: import sys diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 316e20365f79..82343cb4d601 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -79,7 +79,8 @@ def __call__(self): Offset = Union[torch.Tensor, int] -def make_flex_block_mask( +# TODO: rename to make_flex_block_mask +def make_flex_block_causal_mask( attention_mask_2d: torch.Tensor, attention_chunk_size: Optional[int] = None, query_length=None, diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 515eef5ae988..5c14df042e89 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -45,7 +45,7 @@ def sdpa_attention_forward( if is_causal is None: # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns - is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) + is_causal = query.shape[2] > 1 and attention_mask is None and module.is_causal # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools. diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7523000ff524..f11a96391870 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -49,7 +49,7 @@ if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_mask + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -665,7 +665,7 @@ def forward( and (self.config.attention_dropout == 0 or not self.training) ): if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_mask(attention_mask, is_causal=False) + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) @@ -892,11 +892,11 @@ def forward( and (self.config.attention_dropout == 0 or not self.training) ): if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_mask(attention_mask) + attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) # while we need to create our specific block mask regardless elif attention_mask is None: - attention_mask = make_flex_block_mask( + attention_mask = make_flex_block_causal_mask( torch.ones( size=(input_shape), device=inputs_embeds.device, @@ -927,7 +927,7 @@ def forward( and (self.config.attention_dropout == 0 or not self.training) ): if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_mask( + encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, query_length=input_shape[-1], is_causal=False, From 131de1b0abb5aa990c39bd92a15f1dc63b0b4e16 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 13 May 2025 10:27:08 +0200 Subject: [PATCH 07/68] more todos --- src/transformers/integrations/flex_attention.py | 2 +- src/transformers/models/bart/configuration_bart.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 82343cb4d601..ea2bc1b12ee5 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -79,7 +79,7 @@ def __call__(self): Offset = Union[torch.Tensor, int] -# TODO: rename to make_flex_block_mask +# TODO: rename to make_flex_block_mask for clarity as it's not only causal anymore def make_flex_block_causal_mask( attention_mask_2d: torch.Tensor, attention_chunk_size: Optional[int] = None, diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 4ce4316e3c03..1dc4c6101d0b 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -107,6 +107,7 @@ class BartConfig(PretrainedConfig): model_type = "bart" keys_to_ignore_at_inference = ["past_key_values"] attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} + # TODO: add tp plan def __init__( self, From c8b8ed6c33dd2425b606cae953b7f9cf0ddb63e1 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 13 May 2025 12:09:30 +0200 Subject: [PATCH 08/68] refactor mask creation for reuse --- src/transformers/models/bart/modeling_bart.py | 212 +++++++++++------- 1 file changed, 130 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index f11a96391870..cd060dc01cb3 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -649,26 +649,12 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask _unsupported_features = output_attentions is True or head_mask is not None - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -724,6 +710,33 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class BartDecoder(BartPreTrainedModel): """ @@ -874,69 +887,20 @@ def forward( inputs_embeds = self.embed_tokens(input) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + ) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -1037,6 +1001,90 @@ def forward( cross_attentions=all_cross_attentions, ) + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class BartModel(BartPreTrainedModel): From c0d83b64de5e8f82d11d611bc443f7033e99eca2 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 13 May 2025 13:30:26 +0200 Subject: [PATCH 09/68] modular attempt at biogpt --- .../models/biogpt/modeling_biogpt.py | 356 ++++----- .../models/biogpt/modular_biogpt.py | 726 ++++++++++++++++++ 2 files changed, 877 insertions(+), 205 deletions(-) create mode 100644 src/transformers/models/biogpt/modular_biogpt.py diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 7fdb5a561899..284c6b7c33b1 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -1,3 +1,9 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_biogpt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 # coding=utf-8 # Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. # @@ -12,38 +18,38 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""PyTorch BioGPT model.""" import math -from typing import Optional, Tuple, Union +from functools import partial +from typing import Callable, Optional, Tuple, Union import torch -import torch.utils.checkpoint -from torch import nn +import torch.nn as nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, SequenceClassifierOutputWithPast, TokenClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import ( - auto_docstring, - logging, -) +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging from .configuration_biogpt import BioGptConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) -# copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt -# TODO @ArthurZucker bring copied from back class BioGptLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -68,7 +74,6 @@ def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int return super().forward(positions + self.offset) -# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt class BioGptScaledWordEmbedding(nn.Embedding): """ This module overrides nn.Embeddings' forward by multiplying with embeddings scale. @@ -82,7 +87,33 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class BioGptAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -117,9 +148,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -128,6 +156,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -135,10 +166,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -153,18 +188,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -176,182 +211,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->BioGpt -class BioGptSdpaAttention(BioGptAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value - - -BIOGPT_ATTENTION_CLASSES = { - "eager": BioGptAttention, - "sdpa": BioGptSdpaAttention, -} + return attn_output, attn_weights, past_key_value class BioGptDecoderLayer(nn.Module): @@ -359,12 +253,13 @@ def __init__(self, config: BioGptConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = BIOGPT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BioGptAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_probs_dropout_prob, is_decoder=True, is_causal=True, + config=config, ) self.dropout = config.hidden_dropout_prob self.activation_fn = ACT2FN[config.hidden_act] @@ -384,6 +279,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -414,6 +310,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + **flash_attn_kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -444,7 +341,9 @@ class BioGptPreTrainedModel(PreTrainedModel): config_class = BioGptConfig base_model_prefix = "biogpt" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" @@ -483,7 +382,6 @@ def __init__(self, config: BioGptConfig): self.layer_norm = nn.LayerNorm(self.embed_dim) self.gradient_checkpointing = False - self._use_sdpa = config._attn_implementation == "sdpa" # Initialize weights and apply final processing self.post_init() @@ -505,7 +403,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, # NOOP kwargs, for now + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -547,16 +445,14 @@ def forward( # embed positions positions = self.embed_positions(attention_mask, past_key_values_length) - if self._use_sdpa and not output_attentions and head_mask is None: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) hidden_states = inputs_embeds + positions @@ -587,7 +483,7 @@ def forward( if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( - decoder_layer.__call__, + partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, attention_mask, head_mask[idx] if head_mask is not None else None, @@ -603,6 +499,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + **flash_attn_kwargs, ) hidden_states = layer_outputs[0] @@ -635,6 +532,54 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + @auto_docstring( custom_intro=""" @@ -672,7 +617,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **kwargs, + **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): @@ -692,6 +637,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + **kwargs, ) sequence_output = outputs[0] diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py new file mode 100644 index 000000000000..018d1143c34c --- /dev/null +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -0,0 +1,726 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BioGPT model.""" + +import math +from functools import partial +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...activations import ACT2FN +from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import ( + BaseModelOutputWithPastAndCrossAttentions, + CausalLMOutputWithCrossAttentions, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...processing_utils import Unpack +from ...utils import ( + LossKwargs, + auto_docstring, + is_torch_flex_attn_available, + logger, + logging, +) +from ..bart.modeling_bart import ( + BartAttention, + BartDecoderLayer, + BartScaledWordEmbedding, +) +from .configuration_biogpt import BioGptConfig + + +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + +class BioGptLearnedPositionalEmbedding(nn.Embedding): + """ + This module learns positional embeddings up to a fixed maximum size. + """ + + def __init__(self, num_embeddings: int, embedding_dim: int): + # BioGpt is set up so that if padding_idx is specified then offset the embedding ids by 2 + # and adjust num_embeddings appropriately. Other models don't have this hack + self.offset = 2 + super().__init__(num_embeddings + self.offset, embedding_dim) + + def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + """`input_ids_shape` is expected to be [bsz x seqlen].""" + attention_mask = attention_mask.long() + + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + + # cut positions if `past_key_values_length` is > 0 + positions = positions[:, past_key_values_length:] + + return super().forward(positions + self.offset) + + +class BioGptScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +class BioGptAttention(BartAttention): + pass + + +class BioGptDecoderLayer(BartDecoderLayer): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.embed_dim = config.hidden_size + + self.self_attn = BioGptAttention( + embed_dim=self.embed_dim, + num_heads=config.num_attention_heads, + dropout=config.attention_probs_dropout_prob, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.hidden_dropout_prob + self.activation_fn = ACT2FN[config.hidden_act] + + self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim) + + del self.encoder_attn + del self.encoder_attn_layer_norm + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + residual = hidden_states + + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + **flash_attn_kwargs, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +@auto_docstring +class BioGptPreTrainedModel(PreTrainedModel): + config_class = BioGptConfig + base_model_prefix = "biogpt" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +@auto_docstring +class BioGptModel(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.config = config + self.layerdrop = config.layerdrop + self.dropout = config.hidden_dropout_prob + self.embed_dim = config.hidden_size + self.padding_idx = config.pad_token_id + embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0 + + self.embed_tokens = BioGptScaledWordEmbedding( + config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale + ) + self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) + + self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layer_norm = nn.LayerNorm(self.embed_dim) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **flash_attn_kwargs: Unpack[FlashAttentionKwargs], + ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input = input_ids + input_shape = input.size() + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # past_key_values_length + past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input) + + if attention_mask is None: + attention_mask = torch.ones( + (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length), + dtype=torch.bool, + device=inputs_embeds.device, + ) + elif attention_mask.shape[1] != past_key_values_length + input_shape[1]: + raise ValueError( + f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " + f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)" + ) + + # embed positions + positions = self.embed_positions(attention_mask, past_key_values_length) + + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + + hidden_states = inputs_embeds + positions + + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + all_cross_attentions = None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + if output_hidden_states: + all_hidden_states += (hidden_states,) + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: + continue + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + partial(decoder_layer.__call__, **flash_attn_kwargs), + hidden_states, + attention_mask, + head_mask[idx] if head_mask is not None else None, + None, + output_attentions, + use_cache, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + **flash_attn_kwargs, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + hidden_states = self.layer_norm(hidden_states) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple( + v + for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + cross_attentions=all_cross_attentions, + ) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... + + +@auto_docstring( + custom_intro=""" + BioGPT Model with a `language modeling` head on top for CLM fine-tuning. + """ +) +class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["output_projection.weight"] + + def __init__(self, config): + super().__init__(config) + + self.biogpt = BioGptModel(config) + self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_output_embeddings(self): + return self.output_projection + + def set_output_embeddings(self, new_embeddings): + self.output_projection = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs: Unpack[KwargsForCausalLM], + ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set + `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100` + are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + outputs = self.biogpt( + input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + past_key_values=past_key_values, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + sequence_output = outputs[0] + prediction_scores = self.output_projection(sequence_output) + + lm_loss = None + if labels is not None: + lm_loss = self.loss_function( + prediction_scores, + labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (prediction_scores,) + outputs[1:] + return ((lm_loss,) + output) if lm_loss is not None else output + + return CausalLMOutputWithCrossAttentions( + loss=lm_loss, + logits=prediction_scores, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + cross_attentions=outputs.cross_attentions, + ) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +@auto_docstring +class BioGptForTokenClassification(BioGptPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + + self.biogpt = BioGptModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + else: + classifier_dropout = config.hidden_dropout_prob + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, TokenClassifierOutput]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if attention_mask is not None: + active_loss = attention_mask.view(-1) == 1 + active_logits = logits.view(-1, self.num_labels) + active_labels = torch.where( + active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels) + ) + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) + + if not return_dict: + output = (logits,) + transformer_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The BioGpt Model transformer with a sequence classification head on top (linear layer). + + [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it is required to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). + """ +) +class BioGptForSequenceClassification(BioGptPreTrainedModel): + def __init__(self, config: BioGptConfig): + super().__init__(config) + self.num_labels = config.num_labels + self.biogpt = BioGptModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r""" + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.biogpt( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size, sequence_length = input_ids.shape[:2] + else: + batch_size, sequence_length = inputs_embeds.shape[:2] + + if self.config.pad_token_id is None: + sequence_length = -1 + else: + if input_ids is not None: + sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device) + else: + sequence_length = -1 + logger.warning_once( + f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be " + "unexpected if using padding tokens in conjunction with `inputs_embeds.`" + ) + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length] + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + def get_input_embeddings(self): + return self.biogpt.embed_tokens + + def set_input_embeddings(self, value): + self.biogpt.embed_tokens = value + + +__all__ = [ + "BioGptForCausalLM", + "BioGptForTokenClassification", + "BioGptForSequenceClassification", + "BioGptModel", + "BioGptPreTrainedModel", +] From 59cf07d9728be48647201098472fd04ffca59e63 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 13 May 2025 18:29:57 +0200 Subject: [PATCH 10/68] first batch of other models --- src/transformers/models/bart/modeling_bart.py | 8 +- .../modeling_bigbird_pegasus.py | 148 +-- .../models/biogpt/modeling_biogpt.py | 4 + .../models/biogpt/modular_biogpt.py | 2 - .../models/blenderbot/modeling_blenderbot.py | 318 +++++-- .../modeling_blenderbot_small.py | 278 ++++-- .../data2vec/modeling_data2vec_audio.py | 425 +++------ .../models/informer/modeling_informer.py | 491 ++++++---- .../models/informer/modular_informer.py | 858 ++++++++++++++++++ .../models/mbart/modeling_mbart.py | 567 +++++------- .../modeling_time_series_transformer.py | 320 +++++-- .../models/wav2vec2/modeling_wav2vec2.py | 484 ++++------ 12 files changed, 2372 insertions(+), 1531 deletions(-) create mode 100644 src/transformers/models/informer/modular_informer.py diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index cd060dc01cb3..b3612f69f473 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -44,7 +44,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_bart import BartConfig @@ -167,6 +167,10 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def forward( self, hidden_states: torch.Tensor, @@ -886,7 +890,7 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index efcf2dc44360..031a73145f99 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -26,6 +26,7 @@ from ...activations import ACT2FN from ...generation import GenerationMixin from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,7 +36,8 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import auto_docstring, logging from .configuration_bigbird_pegasus import BigBirdPegasusConfig @@ -1165,6 +1167,34 @@ def forward( return outputs +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder class BigBirdPegasusDecoderAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1200,6 +1230,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -1211,6 +1242,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -1218,10 +1252,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -1236,18 +1274,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -1259,69 +1297,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class BigBirdPegasusEncoderLayer(nn.Module): @@ -1421,6 +1431,7 @@ def __init__(self, config: BigBirdPegasusConfig): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1433,6 +1444,7 @@ def __init__(self, config: BigBirdPegasusConfig): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 284c6b7c33b1..1aadcd7715ac 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -148,6 +148,10 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 018d1143c34c..0afa84d4c07b 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -21,7 +21,6 @@ import torch import torch.nn as nn import torch.utils.checkpoint -from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN @@ -44,7 +43,6 @@ auto_docstring, is_torch_flex_attn_available, logger, - logging, ) from ..bart.modeling_bart import ( BartAttention, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1eef92764281..8a4e09841683 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -18,7 +18,7 @@ import math import os import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -27,7 +27,13 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,12 +41,17 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel from .configuration_blenderbot import BlenderbotConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -92,6 +103,34 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot class BlenderbotAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -127,6 +166,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -138,6 +178,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -145,10 +188,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -163,18 +210,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -186,72 +233,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT @@ -260,7 +276,7 @@ def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -329,7 +345,7 @@ def __init__(self, config: BlenderbotConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -342,7 +358,7 @@ def __init__(self, config: BlenderbotConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BlenderbotAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -449,6 +465,9 @@ class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -584,10 +603,12 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -645,6 +666,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class BlenderbotDecoder(BlenderbotPreTrainedModel): """ @@ -792,16 +841,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) @@ -901,6 +955,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class BlenderbotModel(BlenderbotPreTrainedModel): diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index eefb467e6059..0b5488d5c00e 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -25,7 +25,13 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -33,11 +39,16 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_blenderbot_small import BlenderbotSmallConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -76,6 +87,34 @@ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): return super().forward(positions) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BlenderbotSmall class BlenderbotSmallAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -111,6 +150,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -122,6 +162,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -129,10 +172,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -147,18 +194,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -170,69 +217,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL @@ -241,7 +260,7 @@ def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotSmallAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -306,19 +325,13 @@ def forward( return outputs -# TODO: Implement attention with SDPA for TimeSeriesTransformer. -BLENDERBOT_SMALL_ATTENTION_CLASSES = { - "eager": BlenderbotSmallAttention, -} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): def __init__(self, config: BlenderbotSmallConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = BlenderbotSmallAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -331,7 +344,7 @@ def __init__(self, config: BlenderbotSmallConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = BlenderbotSmallAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -438,6 +451,9 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -572,10 +588,12 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -630,6 +648,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): """ @@ -778,12 +824,21 @@ def forward( attention_mask, input_shape, inputs_embeds, past_key_values_length ) - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) @@ -883,6 +938,51 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + @auto_docstring class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 3db2fc0c685f..5b908f0c2612 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -6,7 +6,7 @@ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -16,7 +16,8 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,13 +26,14 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_peft_available, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging from .configuration_data2vec_audio import Data2VecAudioConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -167,6 +169,33 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class Data2VecAudioAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -201,6 +230,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -212,6 +242,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -219,10 +252,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -237,18 +274,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -260,303 +297,43 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class Data2VecAudioFlashAttention2(Data2VecAudioAttention): - """ - Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # Data2VecAudioFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("Data2VecAudioFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Data2VecAudioSdpaAttention(Data2VecAudioAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class Data2VecAudioFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -581,21 +358,15 @@ def forward(self, hidden_states): return hidden_states -DATA2VEC_AUDIO_ATTENTION_CLASSES = { - "eager": Data2VecAudioAttention, - "sdpa": Data2VecAudioSdpaAttention, - "flash_attention_2": Data2VecAudioFlashAttention2, -} - - class Data2VecAudioEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Data2VecAudioAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -632,7 +403,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -649,16 +419,13 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -707,6 +474,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Data2VecAudioAdapterLayer(nn.Module): def __init__(self, config): @@ -765,6 +559,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 98acd1bd13be..233874521de7 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1,27 +1,23 @@ -# coding=utf-8 -# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch Informer model.""" - -from typing import List, Optional, Tuple, Union +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/informer/modular_informer.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_informer.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -29,19 +25,20 @@ Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import ( - auto_docstring, - logging, -) +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_informer import InformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesFeatureEmbedder with TimeSeries->Informer class InformerFeatureEmbedder(nn.Module): """ Embed a sequence of categorical features. @@ -76,7 +73,6 @@ def forward(self, features: torch.Tensor) -> torch.Tensor: ) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerStdScaler(nn.Module): """ Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by @@ -112,7 +108,6 @@ def forward( return (data - loc) / scale, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerMeanScaler(nn.Module): """ Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data @@ -167,7 +162,6 @@ def forward( return scaled_data, torch.zeros_like(scale), scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer class InformerNOPScaler(nn.Module): """ Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data. @@ -195,40 +189,6 @@ def forward( return data, loc, scale -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average -def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: - """ - Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, - meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. - - Args: - input_tensor (`torch.FloatTensor`): - Input tensor, of which the average must be computed. - weights (`torch.FloatTensor`, *optional*): - Weights tensor, of the same shape as `input_tensor`. - dim (`int`, *optional*): - The dim along which to average `input_tensor`. - - Returns: - `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. - """ - if weights is not None: - weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) - sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) - return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights - else: - return input_tensor.mean(dim=dim) - - -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll -def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: - """ - Computes the negative log likelihood loss from input distribution with respect to target. - """ - return -input.log_prob(target) - - -# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Informer class InformerSinusoidalPositionalEmbedding(nn.Embedding): """This module produces sinusoidal positional embeddings of any length.""" @@ -260,7 +220,6 @@ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) return super().forward(positions) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Info class InformerValueEmbedding(nn.Module): def __init__(self, feature_size, d_model): super().__init__() @@ -270,7 +229,54 @@ def forward(self, x): return self.value_projection(x) -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Informer +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config_class = InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class InformerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -305,6 +311,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -316,6 +323,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -323,10 +333,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -341,18 +355,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -364,69 +378,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class InformerProbSparseAttention(nn.Module): @@ -675,6 +661,14 @@ class InformerEncoderLayer(nn.Module): def __init__(self, config: InformerConfig): super().__init__() self.embed_dim = config.d_model + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) + self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + if config.attention_type == "prob": self.self_attn = InformerProbSparseAttention( embed_dim=self.embed_dim, @@ -687,14 +681,8 @@ def __init__(self, config: InformerConfig): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) def forward( self, @@ -751,22 +739,6 @@ class InformerDecoderLayer(nn.Module): def __init__(self, config: InformerConfig): super().__init__() self.embed_dim = config.d_model - - if config.attention_type == "prob": - self.self_attn = InformerProbSparseAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - sampling_factor=config.sampling_factor, - is_decoder=True, - ) - else: - self.self_attn = InformerAttention( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout @@ -777,12 +749,28 @@ def __init__(self, config: InformerConfig): config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + def forward( self, hidden_states: torch.Tensor, @@ -873,31 +861,10 @@ def forward( return outputs -@auto_docstring -class InformerPreTrainedModel(PreTrainedModel): - config_class = InformerConfig - base_model_prefix = "model" - main_input_name = "past_values" - supports_gradient_checkpointing = True - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, (nn.Linear, nn.Conv1d)): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, InformerSinusoidalPositionalEmbedding): - module._init_weight() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - class InformerEncoder(InformerPreTrainedModel): """ - Informer encoder consisting of *config.encoder_layers* self attention layers with distillation layers. Each - attention layer is an [`InformerEncoderLayer`]. + Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a + [`InformerEncoderLayer`]. Args: config: InformerConfig @@ -908,7 +875,6 @@ def __init__(self, config: InformerConfig): self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop - self.gradient_checkpointing = False if config.prediction_length is None: raise ValueError("The `prediction_length` config needs to be specified.") @@ -918,6 +884,7 @@ def __init__(self, config: InformerConfig): ) self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) + self.gradient_checkpointing = False if config.distil: self.conv_layers = nn.ModuleList( @@ -926,7 +893,6 @@ def __init__(self, config: InformerConfig): self.conv_layers.append(None) else: self.conv_layers = [None] * config.encoder_layers - # Initialize weights and apply final processing self.post_init() @@ -1044,11 +1010,37 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerDecoder with TimeSeriesTransformer->Informer,TimeSeriesTransformerConfig->InformerConfig,time-series-transformer->informer,Transformer->Informer,TimeSeries->Informer class InformerDecoder(InformerPreTrainedModel): """ - Informer decoder consisting of *config.decoder_layers* layers. Each layer is a + Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`InformerDecoderLayer`] Args: @@ -1156,16 +1148,21 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) @@ -1262,9 +1259,92 @@ def forward( cross_attentions=all_cross_attentions, ) + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerModel with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer,TimeSeries->Informer class InformerModel(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) @@ -1612,11 +1692,42 @@ def forward( ) +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + +def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: + """ + Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, + meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`. + + Args: + input_tensor (`torch.FloatTensor`): + Input tensor, of which the average must be computed. + weights (`torch.FloatTensor`, *optional*): + Weights tensor, of the same shape as `input_tensor`. + dim (`int`, *optional*): + The dim along which to average `input_tensor`. + + Returns: + `torch.FloatTensor`: The tensor with values averaged along the specified `dim`. + """ + if weights is not None: + weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor)) + sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0) + return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights + else: + return input_tensor.mean(dim=dim) + + @auto_docstring -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerForPrediction with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer class InformerForPrediction(InformerPreTrainedModel): def __init__(self, config: InformerConfig): super().__init__(config) + self.model = InformerModel(config) if config.distribution_output == "student_t": self.distribution_output = StudentTOutput(dim=config.input_size) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py new file mode 100644 index 000000000000..edd12c192dc8 --- /dev/null +++ b/src/transformers/models/informer/modular_informer.py @@ -0,0 +1,858 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_outputs import ( + BaseModelOutput, +) +from ...modeling_utils import PreTrainedModel +from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput +from ...utils import ( + auto_docstring, +) +from ..bart.modeling_bart import BartAttention +from ..time_series_transformer.modeling_time_series_transformer import ( + TimeSeriesFeatureEmbedder, + TimeSeriesMeanScaler, + TimeSeriesNOPScaler, + TimeSeriesSinusoidalPositionalEmbedding, + TimeSeriesStdScaler, + TimeSeriesTransformerDecoder, + TimeSeriesTransformerDecoderLayer, + TimeSeriesTransformerEncoder, + TimeSeriesTransformerEncoderLayer, + TimeSeriesTransformerForPrediction, + TimeSeriesTransformerModel, + TimeSeriesValueEmbedding, +) +from .configuration_informer import InformerConfig + + +class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder): + pass + + +class InformerStdScaler(TimeSeriesStdScaler): + pass + + +class InformerMeanScaler(TimeSeriesMeanScaler): + pass + + +class InformerNOPScaler(TimeSeriesNOPScaler): + pass + + +class InformerSinusoidalPositionalEmbedding(TimeSeriesSinusoidalPositionalEmbedding): + pass + + +class InformerValueEmbedding(TimeSeriesValueEmbedding): + pass + + +@auto_docstring +class InformerPreTrainedModel(PreTrainedModel): + config_class = InformerConfig + base_model_prefix = "model" + main_input_name = "past_values" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, (nn.Linear, nn.Conv1d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, InformerSinusoidalPositionalEmbedding): + module._init_weight() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class InformerAttention(BartAttention): + pass + +class InformerProbSparseAttention(nn.Module): + """Probabilistic Attention mechanism to select the "active" + queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and + memory requirements of vanilla attention""" + + def __init__( + self, + embed_dim: int, + num_heads: int, + dropout: float = 0.0, + is_decoder: bool = False, + sampling_factor: int = 5, + bias: bool = True, + ): + super().__init__() + self.factor = sampling_factor + self.embed_dim = embed_dim + self.num_heads = num_heads + self.dropout = dropout + self.head_dim = embed_dim // num_heads + + if (self.head_dim * num_heads) != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}" + f" and `num_heads`: {num_heads})." + ) + self.scaling = self.head_dim**-0.5 + self.is_decoder = is_decoder + + self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + # if key_value_states are provided this layer is used as a cross-attention layer + # for the decoder + is_cross_attention = key_value_states is not None + + bsz, tgt_len, _ = hidden_states.size() + + # get query proj + query_states = self.q_proj(hidden_states) * self.scaling + # get key, value proj + # `past_key_value[0].shape[2] == key_value_states.shape[1]` + # is checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + if ( + is_cross_attention + and past_key_value is not None + and past_key_value[0].shape[2] == key_value_states.shape[1] + ): + # reuse k,v, cross_attentions + key_states = past_key_value[0] + value_states = past_key_value[1] + elif is_cross_attention: + # cross_attentions + key_states = self._shape(self.k_proj(key_value_states), -1, bsz) + value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + elif past_key_value is not None: + # reuse k, v, self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + else: + # self_attention + key_states = self._shape(self.k_proj(hidden_states), -1, bsz) + value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) + + proj_shape = (bsz * self.num_heads, -1, self.head_dim) + query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + key_states = key_states.reshape(*proj_shape) + value_states = value_states.reshape(*proj_shape) + + key_states_time_length = key_states.size(1) # L_K + log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K + + query_states_time_length = query_states.size(1) # L_Q + log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q + + u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length) + u = min(self.factor * log_query_states_time_length, query_states_time_length) + + if key_states_time_length > 0: + index_sample = torch.randint(0, key_states_time_length, (u_part,)) + k_sample = key_states[:, index_sample, :] + else: + k_sample = key_states + + queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled + + # find the Top_k query with sparsity measurement + if u > 0: + sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div( + queries_keys_sample.sum(dim=-1), key_states_time_length + ) # M + top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top + + # calculate q_reduce: query_states[:, top_u_sparsity_measurement] + dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1) + q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement] + else: + q_reduce = query_states + top_u_sparsity_measurement = None + + # Use q_reduce to calculate attention weights + attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2)) + + src_len = key_states.size(1) + if attn_weights.size() != (bsz * self.num_heads, u, src_len): + raise ValueError( + f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, tgt_len, src_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + ) + prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape( + bsz * self.num_heads, tgt_len, src_len + ) + + if top_u_sparsity_measurement is not None: + dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1) + prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :] + + attn_weights = attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view( + bsz, self.num_heads, u, src_len + ) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + attn_weights = nn.functional.softmax(attn_weights, dim=-1) + + if layer_head_mask is not None: + if layer_head_mask.size() != (self.num_heads,): + raise ValueError( + f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" + f" {layer_head_mask.size()}" + ) + attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len) + + if output_attentions: + # this operation is a bit awkward, but it's required to + # make sure that attn_weights keeps its gradient. + # In order to do so, attn_weights have to be reshaped + # twice and have to be reused in the following + attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len) + attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len) + else: + attn_weights_reshaped = None + + attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.bmm(attn_probs, value_states) + + # calculate context for updating the attn_output, based on: + # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74 + if self.is_decoder: + # cast to float32 before operation to avoid overflow + context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype) + else: + v_mean_dim_time = value_states.mean(dim=-2) + context = ( + v_mean_dim_time.unsqueeze(dim=1) + .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1)) + .clone() + ) + + if top_u_sparsity_measurement is not None: + # update context: copy the attention output to the context at top_u_sparsity_measurement index + dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1) + context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output + attn_output = context + + if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) + attn_output = attn_output.transpose(1, 2) + + # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be + # partitioned across GPUs when using tensor-parallelism. + attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights_reshaped, past_key_value + + +# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py +class InformerConvLayer(nn.Module): + def __init__(self, c_in): + super().__init__() + self.downConv = nn.Conv1d( + in_channels=c_in, + out_channels=c_in, + kernel_size=3, + padding=1, + padding_mode="circular", + ) + self.norm = nn.BatchNorm1d(c_in) + self.activation = nn.ELU() + self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1) + + def forward(self, x): + x = self.downConv(x.permute(0, 2, 1)) + x = self.norm(x) + x = self.activation(x) + x = self.maxPool(x) + x = x.transpose(1, 2) + return x + + +class InformerEncoderLayer(TimeSeriesTransformerEncoderLayer): + def __init__(self, config: InformerConfig): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer): + def __init__(self, config: InformerConfig): + super().__init__(config) + + del self.self_attn + + if config.attention_type == "prob": + self.self_attn = InformerProbSparseAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + sampling_factor=config.sampling_factor, + ) + else: + self.self_attn = InformerAttention( + embed_dim=self.embed_dim, + num_heads=config.encoder_attention_heads, + dropout=config.attention_dropout, + config=config, + ) + + +class InformerEncoder(TimeSeriesTransformerEncoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + + self.dropout = config.dropout + self.layerdrop = config.encoder_layerdrop + self.gradient_checkpointing = False + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + if config.distil: + self.conv_layers = nn.ModuleList( + [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)] + ) + self.conv_layers.append(None) + else: + self.conv_layers = [None] * config.encoder_layers + + # Initialize weights and apply final processing + self.post_init() + + def forward( + self, + attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + hidden_states = self.value_embedding(inputs_embeds) + embed_pos = self.embed_positions(inputs_embeds.size()) + + hidden_states = self.layernorm_embedding(hidden_states + embed_pos) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + + # expand attention_mask + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + # check if head_mask has a correct number of layers specified if desired + if head_mask is not None: + if head_mask.size()[0] != (len(self.layers)): + raise ValueError( + f"The head_mask should be specified for {len(self.layers)} layers, but it is for" + f" {head_mask.size()[0]}." + ) + + for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)): + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + to_drop = False + if self.training: + dropout_probability = torch.rand([]) + if dropout_probability < self.layerdrop: # skip the layer + to_drop = True + + if to_drop: + layer_outputs = (None, None) + else: + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + (head_mask[idx] if head_mask is not None else None), + output_attentions, + ) + if conv_layer is not None: + output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + layer_head_mask=(head_mask[idx] if head_mask is not None else None), + output_attentions=output_attentions, + ) + if conv_layer is not None: + output = conv_layer(layer_outputs[0]) + layer_outputs = (output,) + layer_outputs[1:] + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions + ) + + +class InformerDecoder(TimeSeriesTransformerDecoder): + def __init__(self, config: InformerConfig): + super().__init__(config) + self.dropout = config.dropout + self.layerdrop = config.decoder_layerdrop + if config.prediction_length is None: + raise ValueError("The `prediction_length` config needs to be specified.") + + self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model) + self.embed_positions = InformerSinusoidalPositionalEmbedding( + config.context_length + config.prediction_length, config.d_model + ) + self.layers = nn.ModuleList([InformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layernorm_embedding = nn.LayerNorm(config.d_model) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + +class InformerModel(TimeSeriesTransformerModel, nn.Module): + def __init__(self, config: InformerConfig): + nn.Module().__init__(config) + + if config.scaling == "mean" or config.scaling is True: + self.scaler = InformerMeanScaler(config) + elif config.scaling == "std": + self.scaler = InformerStdScaler(config) + else: + self.scaler = InformerNOPScaler(config) + + if config.num_static_categorical_features > 0: + self.embedder = InformerFeatureEmbedder( + cardinalities=config.cardinality, + embedding_dims=config.embedding_dimension, + ) + + # transformer encoder-decoder and mask initializer + self.encoder = InformerEncoder(config) + self.decoder = InformerDecoder(config) + + # Initialize weights and apply final processing + self.post_init() + + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerModel + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly") + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> last_hidden_state = outputs.last_hidden_state + ```""" + super().forward(**super_kwargs) + + +class InformerForPrediction(TimeSeriesTransformerForPrediction, nn.Module): + def __init__(self, config: InformerConfig): + nn.Module().__init__(config) + + self.model = InformerModel(config) + if config.distribution_output == "student_t": + self.distribution_output = StudentTOutput(dim=config.input_size) + elif config.distribution_output == "normal": + self.distribution_output = NormalOutput(dim=config.input_size) + elif config.distribution_output == "negative_binomial": + self.distribution_output = NegativeBinomialOutput(dim=config.input_size) + else: + raise ValueError(f"Unknown distribution output {config.distribution_output}") + + self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model) + self.target_shape = self.distribution_output.event_shape + + if config.loss == "nll": + self.loss = nll + else: + raise ValueError(f"Unknown loss function {config.loss}") + + # Initialize weights of distribution_output and apply final processing + self.post_init() + + @auto_docstring + def forward(self, **super_kwargs): + r""" + past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): + Past values of the time series, that serve as context in order to predict the future. The sequence size of + this tensor must be larger than the `context_length` of the model, since the model will use the larger size + to construct lag features, i.e. additional values from the past which are added in order to serve as "extra + context". + + The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no + `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest + look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of + the past. + + The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as + `static_categorical_features`, `static_real_features`, `past_time_features` and lags). + + Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`): + Required time features, which the model internally will add to `past_values`. These could be things like + "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These + could also be so-called "age" features, which basically help the model know "at which point in life" a + time-series is. Age features have small values for distant past time steps and increase monotonically the + more we approach the current time step. Holiday features are also a good example of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in + `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*): + Optional static categorical features for which the model will learn an embedding, which it will add to the + values of the time series. + + Static categorical features are features which have the same value for all time steps (static over time). + + A typical example of a static categorical feature is a time series ID. + static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*): + Optional static real features which the model will add to the values of the time series. + + Static real features are features which have the same value for all time steps (static over time). + + A typical example of a static real feature is promotion information. + future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*): + Future values of the time series, that serve as labels for the model. The `future_values` is what the + Transformer needs during training to learn to output, given the `past_values`. + + The sequence length here is equal to `prediction_length`. + + See the demo notebook and code snippets for details. + + Optionally, during training any missing values need to be replaced with zeros and indicated via the + `future_observed_mask`. + + For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of + variates in the time series per time step. + future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`): + Required time features for the prediction window, which the model internally will add to `future_values`. + These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as + Fourier features). These could also be so-called "age" features, which basically help the model know "at + which point in life" a time-series is. Age features have small values for distant past time steps and + increase monotonically the more we approach the current time step. Holiday features are also a good example + of time features. + + These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where + the position encodings are learned from scratch internally as parameters of the model, the Time Series + Transformer requires to provide additional time features. The Time Series Transformer only learns + additional embeddings for `static_categorical_features`. + + Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features + must but known at prediction time. + + The `num_features` here is equal to `config.`num_time_features` + `config.num_dynamic_real_features`. + future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*): + Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected + in `[0, 1]`: + + - 1 for values that are **observed**, + - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros). + + This mask is used to filter out missing values for the final loss calculation. + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*): + Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*) + `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of + hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder. + + Examples: + + ```python + >>> from huggingface_hub import hf_hub_download + >>> import torch + >>> from transformers import InformerForPrediction + + >>> file = hf_hub_download( + ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset" + ... ) + >>> batch = torch.load(file) + + >>> model = InformerForPrediction.from_pretrained( + ... "huggingface/informer-tourism-monthly" + ... ) + + >>> # during training, one provides both past and future values + >>> # as well as possible additional features + >>> outputs = model( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_values=batch["future_values"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> loss = outputs.loss + >>> loss.backward() + + >>> # during inference, one only provides past values + >>> # as well as possible additional features + >>> # the model autoregressively generates future values + >>> outputs = model.generate( + ... past_values=batch["past_values"], + ... past_time_features=batch["past_time_features"], + ... past_observed_mask=batch["past_observed_mask"], + ... static_categorical_features=batch["static_categorical_features"], + ... static_real_features=batch["static_real_features"], + ... future_time_features=batch["future_time_features"], + ... ) + + >>> mean_prediction = outputs.sequences.mean(dim=1) + ```""" + super().forward(**super_kwargs) + + +__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"] diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 560b8eca9bd2..013fd6883df7 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -31,7 +31,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -41,13 +41,14 @@ Seq2SeqQuestionAnsweringModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_mbart import MBartConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -110,6 +111,34 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart class MBartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -145,6 +174,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -156,6 +186,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -163,10 +196,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -181,18 +218,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -204,318 +241,50 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart -class MBartFlashAttention2(MBartAttention): - """ - MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # MBartFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("MBartFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart -class MBartSdpaAttention(MBartAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -MBART_ATTENTION_CLASSES = { - "eager": MBartAttention, - "sdpa": MBartSdpaAttention, - "flash_attention_2": MBartFlashAttention2, -} - class MBartEncoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MBartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -583,7 +352,7 @@ def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MBartAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -596,7 +365,7 @@ def __init__(self, config: MBartConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = MBartAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -731,6 +500,7 @@ class MBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -873,18 +643,12 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -941,6 +705,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class MBartDecoder(MBartPreTrainedModel): """ @@ -1092,42 +884,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2": - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + ) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -1228,6 +999,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class MBartModel(MBartPreTrainedModel): diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 6ba0d88e94f1..8f5b87e648f3 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -15,14 +15,20 @@ # limitations under the License. """PyTorch Time Series Transformer model.""" -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -30,15 +36,21 @@ Seq2SeqTSModelOutput, Seq2SeqTSPredictionOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ( auto_docstring, + is_torch_flex_attn_available, logging, ) from .configuration_time_series_transformer import TimeSeriesTransformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -264,6 +276,34 @@ def forward(self, x): return self.value_projection(x) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->TimeSeriesTransformer class TimeSeriesTransformerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -299,6 +339,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -310,6 +351,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -317,10 +361,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -335,18 +383,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -358,69 +406,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER @@ -429,7 +449,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = TimeSeriesTransformerAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -494,19 +514,13 @@ def forward( return outputs -# TODO: Implement attention with SDPA for TimeSeriesTransformer. -TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = { - "eager": TimeSeriesTransformerAttention, -} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerDecoderLayer(nn.Module): def __init__(self, config: TimeSeriesTransformerConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = TimeSeriesTransformerAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -519,7 +533,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = TimeSeriesTransformerAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -627,6 +641,9 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "past_values" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -719,10 +736,12 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -778,6 +797,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): """ @@ -889,16 +936,21 @@ def forward( # past_key_values_length past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) hidden_states = self.value_embedding(inputs_embeds) embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length) @@ -995,6 +1047,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 8b328fcd1a7c..870a6f4eeaa5 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -28,7 +28,11 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -38,7 +42,8 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( ModelOutput, auto_docstring, @@ -46,6 +51,7 @@ check_torch_load_is_safe, is_peft_available, is_safetensors_available, + is_torch_flex_attn_available, logging, ) from .configuration_wav2vec2 import Wav2Vec2Config @@ -58,8 +64,8 @@ from safetensors.torch import load_file as safe_load_file -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -465,6 +471,34 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 class Wav2Vec2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -500,6 +534,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -511,6 +546,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -518,10 +556,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -536,18 +578,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -559,312 +601,43 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Wav2Vec2 -class Wav2Vec2FlashAttention2(Wav2Vec2Attention): - """ - Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # Wav2Vec2FlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("Wav2Vec2FlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class Wav2Vec2SdpaAttention(Wav2Vec2Attention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2 - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -WAV2VEC2_ATTENTION_CLASSES = { - "eager": Wav2Vec2Attention, - "sdpa": Wav2Vec2SdpaAttention, - "flash_attention_2": Wav2Vec2FlashAttention2, -} - - class Wav2Vec2FeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -892,11 +665,12 @@ def forward(self, hidden_states): class Wav2Vec2EncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -927,11 +701,12 @@ def forward(self, hidden_states, attention_mask=None, output_attentions=False): class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = Wav2Vec2Attention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -978,7 +753,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -995,16 +769,13 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -1053,6 +824,34 @@ def forward( attentions=all_self_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Wav2Vec2EncoderStableLayerNorm(nn.Module): def __init__(self, config): @@ -1065,7 +864,6 @@ def __init__(self, config): [Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -1079,19 +877,16 @@ def forward( all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -1142,6 +937,34 @@ def forward( attentions=all_self_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Wav2Vec2GumbelVectorQuantizer(nn.Module): """ @@ -1301,6 +1124,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" From 146b02b60643ac0c13ac1eda66baa5fd9330bf7b Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 11:10:54 +0200 Subject: [PATCH 11/68] fix attn dropout --- src/transformers/models/biogpt/modeling_biogpt.py | 4 ++-- src/transformers/models/biogpt/modular_biogpt.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 1aadcd7715ac..eba6dfc498d0 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -536,7 +536,7 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask with attention_dropout->attention_probs_dropout_prob def _update_causal_mask( self, attention_mask: Union[torch.Tensor, None], @@ -560,7 +560,7 @@ def _update_causal_mask( elif ( self.config._attn_implementation == "flex_attention" and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) + and (self.config.attention_probs_dropout_prob == 0 or not self.training) ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 0afa84d4c07b..351b072fc68e 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -371,7 +371,7 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask with attention_dropout->attention_probs_dropout_prob def _update_causal_mask( self, attention_mask: Union[torch.Tensor, None], @@ -395,7 +395,7 @@ def _update_causal_mask( elif ( self.config._attn_implementation == "flex_attention" and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) + and (self.config.attention_probs_dropout_prob == 0 or not self.training) ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) From b7f0a2bdf2ec2cd80a0e7f6a66b83d524d3553e7 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 11:15:15 +0200 Subject: [PATCH 12/68] fix autoformer copies --- .../models/autoformer/modeling_autoformer.py | 49 ++++++++++++++++--- 1 file changed, 43 insertions(+), 6 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 8cb92caa1a4b..c9aae04f4099 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -26,14 +26,21 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_autoformer import AutoformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -980,10 +987,12 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -1039,6 +1048,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class AutoformerDecoder(AutoformerPreTrainedModel): """ From 00c27dfbf5c8714c6af271aeb8dfaf866dd2bd48 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 11:23:52 +0200 Subject: [PATCH 13/68] hubert --- .../models/hubert/modeling_hubert.py | 476 ++++++------------ .../models/hubert/modular_hubert.py | 1 + 2 files changed, 148 insertions(+), 329 deletions(-) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 575239e4d82e..092b1c55e1eb 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -5,7 +5,7 @@ # modular_hubert.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -15,15 +15,17 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_hubert import HubertConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -224,6 +226,33 @@ def forward(self, hidden_states): return hidden_states +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class HubertAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -258,6 +287,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -269,6 +299,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -276,10 +309,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -294,18 +331,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -317,303 +354,43 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class HubertFlashAttention2(HubertAttention): - """ - Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # HubertFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("HubertFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class HubertSdpaAttention(HubertAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class HubertFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -638,21 +415,15 @@ def forward(self, hidden_states): return hidden_states -HUBERT_ATTENTION_CLASSES = { - "eager": HubertAttention, - "sdpa": HubertSdpaAttention, - "flash_attention_2": HubertFlashAttention2, -} - - class HubertEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = HubertAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -689,7 +460,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -706,16 +476,13 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -764,6 +531,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class HubertAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -793,11 +587,12 @@ def forward(self, hidden_states: torch.FloatTensor): class HubertEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = HubertAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -846,7 +641,6 @@ def __init__(self, config): [HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -860,19 +654,16 @@ def forward( all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -923,6 +714,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + @auto_docstring class HubertPreTrainedModel(PreTrainedModel): diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index b3e3d24cc0e8..5286f9065b58 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -115,6 +115,7 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" From fc41dc2028dd1123dd0cb158ff8f06a1a2dbd7e9 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 13:54:43 +0200 Subject: [PATCH 14/68] another batch of models --- .../models/hubert/modeling_hubert.py | 1 + .../models/m2m_100/modeling_m2m_100.py | 578 +++++--------- .../models/marian/modeling_marian.py | 318 +++++--- .../models/musicgen/modeling_musicgen.py | 537 +++++-------- .../modeling_musicgen_melody.py | 470 ++++------- .../models/nllb_moe/modeling_nllb_moe.py | 325 ++++++-- .../patchtsmixer/modeling_patchtsmixer.py | 152 ++-- .../models/patchtst/modeling_patchtst.py | 147 ++-- .../models/pegasus/modeling_pegasus.py | 318 +++++--- .../models/pegasus_x/modeling_pegasus_x.py | 271 +++++-- .../models/plbart/modeling_plbart.py | 754 ++++++++++-------- .../models/plbart/modular_plbart.py | 602 ++++++++++++++ 12 files changed, 2613 insertions(+), 1860 deletions(-) create mode 100644 src/transformers/models/plbart/modular_plbart.py diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 092b1c55e1eb..5fee0c483bff 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -750,6 +750,7 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 3e01afd4329e..f47e45183cec 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -15,7 +15,7 @@ """PyTorch M2M100 model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -31,20 +31,23 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_m2m_100 import M2M100Config -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -173,6 +176,34 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_ return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->M2M100 class M2M100Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -208,6 +239,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -219,6 +251,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -226,10 +261,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -244,18 +283,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -267,312 +306,50 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->M2M100 -class M2M100FlashAttention2(M2M100Attention): - """ - M2M100 flash attention module. This module inherits from `M2M100Attention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # M2M100FlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("M2M100FlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->M2M100 -class M2M100SdpaAttention(M2M100Attention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "M2M100Model is using M2M100SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100 class M2M100EncoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = M2M100Attention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -635,20 +412,13 @@ def forward( return outputs -M2M100_ATTENTION_CLASSES = { - "eager": M2M100Attention, - "flash_attention_2": M2M100FlashAttention2, - "sdpa": M2M100SdpaAttention, -} - - # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 class M2M100DecoderLayer(nn.Module): def __init__(self, config: M2M100Config): super().__init__() self.embed_dim = config.d_model - self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = M2M100Attention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -661,7 +431,7 @@ def __init__(self, config: M2M100Config): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = M2M100Attention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -771,6 +541,7 @@ class M2M100PreTrainedModel(PreTrainedModel): _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -819,8 +590,6 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = ) self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -899,18 +668,12 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -970,6 +733,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class M2M100Decoder(M2M100PreTrainedModel): """ @@ -1001,8 +792,6 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = self.padding_idx, ) self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1113,42 +902,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - combined_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - combined_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) @@ -1198,7 +966,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1210,7 +978,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1255,6 +1023,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class M2M100Model(M2M100PreTrainedModel): diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index f544641e0d65..2779259d543b 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -26,7 +26,13 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -34,11 +40,16 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_marian import MarianConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -90,6 +101,34 @@ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) return super().forward(positions) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian class MarianAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -125,6 +164,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -136,6 +176,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -143,10 +186,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -161,18 +208,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -184,69 +231,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN @@ -255,7 +274,7 @@ def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MarianAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -320,16 +339,13 @@ def forward( return outputs -MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention} - - # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN class MarianDecoderLayer(nn.Module): def __init__(self, config: MarianConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MarianAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -342,7 +358,7 @@ def __init__(self, config: MarianConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = MarianAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -449,6 +465,9 @@ class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): std = self.config.init_std @@ -588,10 +607,12 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -644,6 +665,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class MarianDecoder(MarianPreTrainedModel): """ @@ -786,16 +835,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) @@ -892,6 +946,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class MarianModel(MarianPreTrainedModel): diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index f906c1a93c23..7f66abd77a32 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -19,7 +19,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -40,7 +40,9 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -48,15 +50,17 @@ ModelOutput, Seq2SeqLMOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -146,6 +150,34 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Musicgen class MusicgenAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -181,6 +213,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -192,6 +225,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -199,10 +235,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -217,18 +257,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -240,334 +280,49 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen -class MusicgenFlashAttention2(MusicgenAttention): - """ - Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # MusicgenFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("MusicgenFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class MusicgenSdpaAttention(MusicgenAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - if ( - attention_mask is not None - and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any() - ): - logger.warning_once( - '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information." - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -MUSICGEN_ATTENTION_CLASSES = { - "eager": MusicgenAttention, - "sdpa": MusicgenSdpaAttention, - "flash_attention_2": MusicgenFlashAttention2, -} - - class MusicgenDecoderLayer(nn.Module): def __init__(self, config: MusicgenDecoderConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MusicgenAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_dropout, @@ -581,7 +336,7 @@ def __init__(self, config: MusicgenDecoderConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = MusicgenAttention( self.embed_dim, config.num_attention_heads, dropout=config.attention_dropout, @@ -693,6 +448,7 @@ class MusicgenPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.initializer_factor @@ -824,40 +580,21 @@ def forward( if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - if self.attn_implementation == "flash_attention_2": - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.attn_implementation == "flash_attention_2": - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + ) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -956,6 +693,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class MusicgenModel(MusicgenPreTrainedModel): diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 2b6930ff3a29..9eb3672e7f2a 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -19,7 +19,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -34,18 +34,25 @@ LogitsProcessorList, StoppingCriteriaList, ) -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from ..auto.configuration_auto import AutoConfig from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + if TYPE_CHECKING: from ...generation.streamers import BaseStreamer @@ -159,6 +166,34 @@ def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MusicgenMelody class MusicgenMelodyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -194,6 +229,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -205,6 +241,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -212,10 +251,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -230,18 +273,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -253,318 +296,49 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody -class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention): - """ - MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # MusicgenMelodyFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("MusicgenMelodyFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody -class MusicgenMelodySdpaAttention(MusicgenMelodyAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -MUSICGEN_MELODY_ATTENTION_CLASSES = { - "eager": MusicgenMelodyAttention, - "sdpa": MusicgenMelodySdpaAttention, - "flash_attention_2": MusicgenMelodyFlashAttention2, -} - - class MusicgenMelodyDecoderLayer(nn.Module): def __init__(self, config: MusicgenMelodyDecoderConfig): super().__init__() self.embed_dim = config.hidden_size - self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = MusicgenMelodyAttention( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, dropout=config.attention_dropout, @@ -649,6 +423,7 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.initializer_factor @@ -790,21 +565,14 @@ def forward( input_shape = inputs_embeds.size()[:-1] - if self.attn_implementation == "flash_attention_2": - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.attn_implementation == "sdpa" and not output_attentions: - # output_attentions=True can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) @@ -886,6 +654,64 @@ def forward( attentions=all_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Ignore copy + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # MusicgenMelody doesn't apply cross attention, hence it's ignored here + # and only exists to not confuse any copy checks + pass + @auto_docstring # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 79002aee3178..b227057d1d30 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -15,7 +15,7 @@ """PyTorch NLLB-MoE model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -25,18 +25,29 @@ from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( MoEModelOutput, MoEModelOutputWithPastAndCrossAttentions, Seq2SeqMoEModelOutput, Seq2SeqMoEOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_nllb_moe import NllbMoeConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -460,6 +471,34 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens return hidden_states, (router_probs, top_1_expert_index) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->NllbMoe,key_value_states->encoder_hidden_states class NllbMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -495,6 +534,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -506,6 +546,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -513,10 +556,16 @@ def forward( # for the decoder is_cross_attention = encoder_hidden_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = ( + (*encoder_hidden_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + ) # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -531,18 +580,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz) + key_states = self.k_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -554,69 +603,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class NllbMoeEncoderLayer(nn.Module): @@ -628,6 +649,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + config=config, ) self.attn_dropout = nn.Dropout(config.dropout) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -710,6 +732,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -717,7 +740,11 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False): self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.cross_attention = NllbMoeAttention( - self.embed_dim, config.decoder_attention_heads, config.attention_dropout, is_decoder=True + self.embed_dim, + config.decoder_attention_heads, + config.attention_dropout, + is_decoder=True, + config=config, ) self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim) if not self.is_sparse: @@ -837,6 +864,9 @@ class NllbMoePreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" @@ -975,10 +1005,12 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_router_probs = () if output_router_logits else None @@ -1042,6 +1074,34 @@ def forward( router_probs=all_router_probs, ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class NllbMoeDecoder(NllbMoePreTrainedModel): """ @@ -1195,18 +1255,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) @@ -1264,7 +1327,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.forward, hidden_states, - combined_attention_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, layer_head_mask, @@ -1276,7 +1339,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=layer_head_mask, @@ -1330,6 +1393,92 @@ def forward( router_probs=all_router_probs, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class NllbMoeModel(NllbMoePreTrainedModel): diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 983961f3b4e2..a089725b6198 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -24,11 +24,18 @@ from transformers.modeling_utils import PreTrainedModel from transformers.utils import ModelOutput +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_patchtsmixer import PatchTSMixerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -234,6 +241,34 @@ def forward(self, inputs: torch.Tensor): return out +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer class PatchTSMixerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -269,6 +304,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -280,6 +316,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -287,10 +326,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -305,18 +348,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -328,69 +371,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class PatchMixerBlock(nn.Module): @@ -423,6 +438,7 @@ def __init__(self, config: PatchTSMixerConfig): embed_dim=config.d_model, num_heads=config.self_attn_heads, dropout=config.dropout, + config=config, ) self.norm_attn = PatchTSMixerNormLayer(config) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index e5785974be5b..45cf9893c2d2 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -16,14 +16,16 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn from ...activations import ACT2CLS +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ModelOutput, auto_docstring, logging from .configuration_patchtst import PatchTSTConfig @@ -32,6 +34,34 @@ logger = logging.get_logger(__name__) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST class PatchTSTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -67,6 +97,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -78,6 +109,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -85,10 +119,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -103,18 +141,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -126,69 +164,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class PatchTSTBatchNorm(nn.Module): @@ -461,6 +471,7 @@ def __init__(self, config: PatchTSTConfig): embed_dim=config.d_model, num_heads=config.num_attention_heads, dropout=config.attention_dropout, + config=config, ) # Add & Norm of the sublayer 1 diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 9d0eafa9c94d..63508f5b3b0e 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -26,7 +26,13 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -34,11 +40,16 @@ Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_pegasus import PegasusConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -91,6 +102,34 @@ def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) return super().forward(positions) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus class PegasusAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -126,6 +165,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -137,6 +177,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -144,10 +187,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -162,18 +209,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -185,72 +232,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS @@ -259,7 +275,7 @@ def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = PegasusAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -328,7 +344,7 @@ def __init__(self, config: PegasusConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = PegasusAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -341,7 +357,7 @@ def __init__(self, config: PegasusConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = PegasusAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -448,6 +464,9 @@ class PegasusPreTrainedModel(PreTrainedModel): config_class = PegasusConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -604,10 +623,12 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -664,6 +685,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class PegasusDecoder(PegasusPreTrainedModel): """ @@ -839,16 +888,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_shape, past_key_values_length) @@ -948,6 +1002,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class PegasusModel(PegasusPreTrainedModel): diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 9656b3afe77f..3bdafa78488a 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -16,7 +16,7 @@ import dataclasses import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -26,18 +26,29 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_pegasus_x import PegasusXConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -116,6 +127,34 @@ def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) - return pe[None].expand(batch_size, -1, -1) +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX class PegasusXAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -151,6 +190,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -162,6 +202,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -169,10 +212,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -187,18 +234,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -210,69 +257,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value class PegasusXGlobalLocalAttention(nn.Module): @@ -635,6 +654,7 @@ def __init__(self, config: PegasusXConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -647,6 +667,7 @@ def __init__(self, config: PegasusXConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + config=config, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -742,6 +763,9 @@ class PegasusXPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -1094,16 +1118,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(inputs_embeds, past_key_values_length) @@ -1191,6 +1220,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class PegasusXModel(PegasusXPreTrainedModel): diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 22157b29c775..b4955c33865b 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1,25 +1,14 @@ -# coding=utf-8 -# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch PLBART model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/plbart/modular_plbart.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_plbart.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch -import torch.utils.checkpoint from torch import nn from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss @@ -31,6 +20,7 @@ _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa, ) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -39,36 +29,54 @@ Seq2SeqModelOutput, Seq2SeqSequenceClassifierOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_plbart import PLBartConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) -# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): +class PLBartScaledWordEmbedding(nn.Embedding): """ - Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not - have a single `decoder_start_token_id` in contrast to other Bart-like models. + This module overrides nn.Embeddings' forward by multiplying with embeddings scale. """ - prev_output_tokens = input_ids.clone() - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): + super().__init__(num_embeddings, embedding_dim, padding_idx) + self.embed_scale = embed_scale - index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) - decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() - prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() - prev_output_tokens[:, 0] = decoder_start_tokens + def forward(self, input_ids: torch.Tensor): + return super().forward(input_ids) * self.embed_scale - return prev_output_tokens + +@auto_docstring +class PLBartPreTrainedModel(PreTrainedModel): + config_class = PLBartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() -# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart class PLBartLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. @@ -91,21 +99,33 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return super().forward(positions + self.offset) -# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart -class PLBartScaledWordEmbedding(nn.Embedding): - """ - This module overrides nn.Embeddings' forward by multiplying with embeddings scale. - """ +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask - def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0): - super().__init__(num_embeddings, embedding_dim, padding_idx) - self.embed_scale = embed_scale + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) - def forward(self, input_ids: torch.Tensor): - return super().forward(input_ids) * self.embed_scale + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart class PLBartAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -140,6 +160,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -151,6 +172,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -158,10 +182,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -176,18 +204,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -199,78 +227,49 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART class PLBartEncoderLayer(nn.Module): def __init__(self, config: PLBartConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = PLBartAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -335,176 +334,6 @@ def forward( return outputs -# TODO: Implement attention with SDPA for PLBart. -PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention} - - -# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART -class PLBartDecoderLayer(nn.Module): - def __init__(self, config: PLBartConfig): - super().__init__() - self.embed_dim = config.d_model - - self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( - embed_dim=self.embed_dim, - num_heads=config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - is_causal=True, - config=config, - ) - self.dropout = config.dropout - self.activation_fn = ACT2FN[config.activation_function] - self.activation_dropout = config.activation_dropout - - self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation]( - self.embed_dim, - config.decoder_attention_heads, - dropout=config.attention_dropout, - is_decoder=True, - config=config, - ) - self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) - self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) - self.final_layer_norm = nn.LayerNorm(self.embed_dim) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - encoder_hidden_states: Optional[torch.Tensor] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - output_attentions: Optional[bool] = False, - use_cache: Optional[bool] = True, - ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: - """ - Args: - hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` - attention_mask (`torch.FloatTensor`): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - encoder_hidden_states (`torch.FloatTensor`): - cross attention input to the layer of shape `(batch, seq_len, embed_dim)` - encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. - layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size - `(encoder_attention_heads,)`. - cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of - size `(decoder_attention_heads,)`. - past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states - output_attentions (`bool`, *optional*): - Whether or not to return the attentions tensors of all attention layers. See `attentions` under - returned tensors for more detail. - """ - residual = hidden_states - - # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( - hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.self_attn_layer_norm(hidden_states) - - # Cross-Attention Block - cross_attn_present_key_value = None - cross_attn_weights = None - if encoder_hidden_states is not None: - residual = hidden_states - - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( - hidden_states=hidden_states, - key_value_states=encoder_hidden_states, - attention_mask=encoder_attention_mask, - layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - output_attentions=output_attentions, - ) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.encoder_attn_layer_norm(hidden_states) - - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - - # Fully Connected - residual = hidden_states - hidden_states = self.activation_fn(self.fc1(hidden_states)) - hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) - hidden_states = self.fc2(hidden_states) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states - hidden_states = self.final_layer_norm(hidden_states) - - outputs = (hidden_states,) - - if output_attentions: - outputs += (self_attn_weights, cross_attn_weights) - - if use_cache: - outputs += (present_key_value,) - - return outputs - - -# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart -class PLBartClassificationHead(nn.Module): - """Head for sentence-level classification tasks.""" - - def __init__( - self, - input_dim: int, - inner_dim: int, - num_classes: int, - pooler_dropout: float, - ): - super().__init__() - self.dense = nn.Linear(input_dim, inner_dim) - self.dropout = nn.Dropout(p=pooler_dropout) - self.out_proj = nn.Linear(inner_dim, num_classes) - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dropout(hidden_states) - hidden_states = self.dense(hidden_states) - hidden_states = torch.tanh(hidden_states) - hidden_states = self.dropout(hidden_states) - hidden_states = self.out_proj(hidden_states) - return hidden_states - - -@auto_docstring -class PLBartPreTrainedModel(PreTrainedModel): - config_class = PLBartConfig - base_model_prefix = "model" - supports_gradient_checkpointing = True - _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] - - def _init_weights(self, module): - std = self.config.init_std - if isinstance(module, nn.Linear): - module.weight.data.normal_(mean=0.0, std=std) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.Embedding): - module.weight.data.normal_(mean=0.0, std=std) - if module.padding_idx is not None: - module.weight.data[module.padding_idx].zero_() - - -# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart class PLBartEncoder(PLBartPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -538,8 +367,6 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = embed_dim, ) self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -625,18 +452,12 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - if self._use_flash_attention_2: - attention_mask = attention_mask if 0 in attention_mask else None - elif self._use_sdpa and head_mask is None and not output_attentions: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -692,8 +513,154 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + +class PLBartDecoderLayer(nn.Module): + def __init__(self, config: PLBartConfig): + super().__init__() + self.embed_dim = config.d_model + + self.self_attn = PLBartAttention( + embed_dim=self.embed_dim, + num_heads=config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + is_causal=True, + config=config, + ) + self.dropout = config.dropout + self.activation_fn = ACT2FN[config.activation_function] + self.activation_dropout = config.activation_dropout + + self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.encoder_attn = PLBartAttention( + self.embed_dim, + config.decoder_attention_heads, + dropout=config.attention_dropout, + is_decoder=True, + config=config, + ) + self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) + self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) + self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) + self.final_layer_norm = nn.LayerNorm(self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + cross_attn_layer_head_mask: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = True, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + encoder_hidden_states (`torch.FloatTensor`): + cross attention input to the layer of shape `(batch, seq_len, embed_dim)` + encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size + `(encoder_attention_heads,)`. + cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of + size `(decoder_attention_heads,)`. + past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + """ + residual = hidden_states + + # Self Attention + # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 + self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None + # add present self-attn cache to positions 1,2 of present_key_value tuple + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + past_key_value=self_attn_past_key_value, + attention_mask=attention_mask, + layer_head_mask=layer_head_mask, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.self_attn_layer_norm(hidden_states) + + # Cross-Attention Block + cross_attn_present_key_value = None + cross_attn_weights = None + if encoder_hidden_states is not None: + residual = hidden_states + + # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple + cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None + hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states=hidden_states, + key_value_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + layer_head_mask=cross_attn_layer_head_mask, + past_key_value=cross_attn_past_key_value, + output_attentions=output_attentions, + ) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.encoder_attn_layer_norm(hidden_states) + + # add cross-attn to positions 3,4 of present_key_value tuple + present_key_value = present_key_value + cross_attn_present_key_value + + # Fully Connected + residual = hidden_states + hidden_states = self.activation_fn(self.fc1(hidden_states)) + hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) + hidden_states = self.fc2(hidden_states) + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = residual + hidden_states + hidden_states = self.final_layer_norm(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights, cross_attn_weights) + + if use_cache: + outputs += (present_key_value,) + + return outputs + -# Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart class PLBartDecoder(PLBartPreTrainedModel): """ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`] @@ -723,8 +690,6 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.d_model, ) self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" - self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -844,42 +809,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self._use_flash_attention_2: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + ) # embed positions positions = self.embed_positions(input, past_key_values_length) @@ -980,6 +924,111 @@ def forward( cross_attentions=all_cross_attentions, ) + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + +# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): + """ + Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not + have a single `decoder_start_token_id` in contrast to other Bart-like models. + """ + prev_output_tokens = input_ids.clone() + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + + index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + prev_output_tokens[:, 0] = decoder_start_tokens + + return prev_output_tokens + @auto_docstring class PLBartModel(PLBartPreTrainedModel): @@ -1307,6 +1356,30 @@ def _reorder_cache(past_key_values, beam_idx): return reordered_past +class PLBartClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__( + self, + input_dim: int, + inner_dim: int, + num_classes: int, + pooler_dropout: float, + ): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.dropout(hidden_states) + hidden_states = self.dense(hidden_states) + hidden_states = torch.tanh(hidden_states) + hidden_states = self.dropout(hidden_states) + hidden_states = self.out_proj(hidden_states) + return hidden_states + + @auto_docstring( custom_intro=""" PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code @@ -1457,7 +1530,6 @@ def forward( ) -# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart class PLBartDecoderWrapper(PLBartPreTrainedModel): """ This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is @@ -1472,7 +1544,11 @@ def forward(self, *args, **kwargs): return self.decoder(*args, **kwargs) -# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base +@auto_docstring( + custom_intro=""" + PLBART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings). + """ +) class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin): _tied_weights_keys = ["lm_head.weight"] diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py new file mode 100644 index 000000000000..f7bc3df027ca --- /dev/null +++ b/src/transformers/models/plbart/modular_plbart.py @@ -0,0 +1,602 @@ +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss + +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutput, + Seq2SeqLMOutput, + Seq2SeqModelOutput, + Seq2SeqSequenceClassifierOutput, +) +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring +from ..bart.modeling_bart import ( + BartClassificationHead, + BartDecoder, + BartEncoder, + BartForCausalLM, + BartScaledWordEmbedding, +) +from .configuration_plbart import PLBartConfig + + +# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right +def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): + """ + Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not + have a single `decoder_start_token_id` in contrast to other Bart-like models. + """ + prev_output_tokens = input_ids.clone() + + if pad_token_id is None: + raise ValueError("self.model.config.pad_token_id has to be defined.") + # replace possible -100 values in labels by `pad_token_id` + prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) + + index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) + decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() + prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() + prev_output_tokens[:, 0] = decoder_start_tokens + + return prev_output_tokens + + +class PLBartScaledWordEmbedding(BartScaledWordEmbedding): + pass + + +@auto_docstring +class PLBartPreTrainedModel(PreTrainedModel): + config_class = PLBartConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True + + def _init_weights(self, module): + std = self.config.init_std + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +class PLBartEncoder(BartEncoder): + pass + + +class PLBartDecoder(BartDecoder): + pass + + +@auto_docstring +class PLBartModel(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + + padding_idx, vocab_size = config.pad_token_id, config.vocab_size + embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0 + self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale) + + self.encoder = PLBartEncoder(config, self.shared) + self.decoder = PLBartDecoder(config, self.shared) + + self.init_weights() + + def get_input_embeddings(self): + return self.shared + + def set_input_embeddings(self, value): + self.shared = value + self.encoder.embed_tokens = self.shared + self.decoder.embed_tokens = self.shared + + def _tie_weights(self): + if self.config.tie_word_embeddings: + self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared) + self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared) + + def get_encoder(self): + return self.encoder + + def get_decoder(self): + return self.decoder + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # different to other models, PLBart automatically creates decoder_input_ids from + # input_ids if no decoder_input_ids are provided + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id) + + if encoder_outputs is None: + encoder_outputs = self.encoder( + input_ids=input_ids, + attention_mask=attention_mask, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True + elif return_dict and not isinstance(encoder_outputs, BaseModelOutput): + encoder_outputs = BaseModelOutput( + last_hidden_state=encoder_outputs[0], + hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None, + attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, + ) + + # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn) + decoder_outputs = self.decoder( + input_ids=decoder_input_ids, + attention_mask=decoder_attention_mask, + encoder_hidden_states=encoder_outputs[0], + encoder_attention_mask=attention_mask, + head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + if not return_dict: + return decoder_outputs + encoder_outputs + + return Seq2SeqModelOutput( + last_hidden_state=decoder_outputs.last_hidden_state, + past_key_values=decoder_outputs.past_key_values, + decoder_hidden_states=decoder_outputs.hidden_states, + decoder_attentions=decoder_outputs.attentions, + cross_attentions=decoder_outputs.cross_attentions, + encoder_last_hidden_state=encoder_outputs.last_hidden_state, + encoder_hidden_states=encoder_outputs.hidden_states, + encoder_attentions=encoder_outputs.attentions, + ) + + +@auto_docstring( + custom_intro=""" + The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code. + """ +) +class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin): + base_model_prefix = "model" + _keys_to_ignore_on_load_missing = ["final_logits_bias"] + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"] + + def __init__(self, config: PLBartConfig): + super().__init__(config) + self.model = PLBartModel(config) + self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings))) + self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False) + + self.init_weights() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + def resize_token_embeddings( + self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True + ) -> nn.Embedding: + new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing) + self._resize_final_logits_bias(new_embeddings.weight.shape[0]) + return new_embeddings + + def _resize_final_logits_bias(self, new_num_tokens: int) -> None: + old_num_tokens = self.final_logits_bias.shape[-1] + if new_num_tokens <= old_num_tokens: + new_bias = self.final_logits_bias[:, :new_num_tokens] + else: + extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device) + new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1) + self.register_buffer("final_logits_bias", new_bias) + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @auto_docstring + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.LongTensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example Mask-filling: + + ```python + >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration + + >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base") + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + + >>> # en_XX is the language symbol id for English + >>> TXT = " Is 0 the Fibonacci number ? en_XX" + >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids + + >>> logits = model(input_ids).logits + >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item() + >>> probs = logits[0, masked_index].softmax(dim=0) + >>> values, predictions = probs.topk(5) + + >>> tokenizer.decode(predictions).split() + ['first', 'same', 'highest', 'result', 'number'] + ``` + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if labels is not None: + if decoder_input_ids is None and decoder_inputs_embeds is None: + decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + encoder_outputs=encoder_outputs, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + lm_logits = self.lm_head(outputs[0]) + lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) + + masked_lm_loss = None + if labels is not None: + loss_fct = CrossEntropyLoss() + masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1)) + + if not return_dict: + output = (lm_logits,) + outputs[1:] + return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output + + return Seq2SeqLMOutput( + loss=masked_lm_loss, + logits=lm_logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return shift_tokens_right(labels, self.config.pad_token_id) + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + # cached cross_attention states don't have to be reordered -> they are always the same + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2]) + + layer_past[2:], + ) + return reordered_past + + +class PLBartClassificationHead(BartClassificationHead): + pass + + +@auto_docstring( + custom_intro=""" + PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code + classification. + """ +) +class PLBartForSequenceClassification(PLBartPreTrainedModel): + _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] + + def __init__(self, config: PLBartConfig, **kwargs): + super().__init__(config, **kwargs) + self.model = PLBartModel(config) + self.classification_head = PLBartClassificationHead( + config.d_model, + config.d_model, + config.num_labels, + config.classifier_dropout, + ) + + # Initialize weights and apply final processing + self.post_init() + + @auto_docstring + # Ignore copy + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + decoder_input_ids: Optional[torch.LongTensor] = None, + decoder_attention_mask: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.Tensor] = None, + decoder_head_mask: Optional[torch.Tensor] = None, + cross_attn_head_mask: Optional[torch.Tensor] = None, + encoder_outputs: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + decoder_inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: + r""" + decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): + Indices of decoder input sequence tokens in the vocabulary. + + Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint. + See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details. + + [What are decoder input IDs?](../glossary#decoder-input-ids) + + PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that + varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If + `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + For translation and summarization training, `decoder_input_ids` should be provided. If no + `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right + for denoising pre-training following the paper. + decoder_attention_mask (: + obj:*torch.LongTensor* of shape `(batch_size, target_sequence_length)`, *optional*): + Default behavior: + generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also be used by default. + cross_attn_head_mask (: + obj:*torch.Tensor* of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify + selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if labels is not None: + use_cache = False + + if input_ids is None and inputs_embeds is not None: + raise NotImplementedError( + f"Passing input embeddings is currently not supported for {self.__class__.__name__}" + ) + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = outputs[0] # last hidden state + + eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) + + if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ + :, -1, : + ] + logits = self.classification_head(sentence_representation) + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.config.num_labels == 1: + self.config.problem_type = "regression" + elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.config.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[1:] + return ((loss,) + output) if loss is not None else output + + return Seq2SeqSequenceClassifierOutput( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + decoder_hidden_states=outputs.decoder_hidden_states, + decoder_attentions=outputs.decoder_attentions, + cross_attentions=outputs.cross_attentions, + encoder_last_hidden_state=outputs.encoder_last_hidden_state, + encoder_hidden_states=outputs.encoder_hidden_states, + encoder_attentions=outputs.encoder_attentions, + ) + + +class PLBartForCausalLM(BartForCausalLM): + @auto_docstring + def forward(**super_kwargs): + r""" + cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): + Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`: + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from transformers import AutoTokenizer, PLBartForCausalLM + + >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base") + >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False) + >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." + >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") + >>> outputs = model(**inputs) + + >>> logits = outputs.logits + >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size] + >>> list(logits.shape) == expected_shape + True + ```""" + super().forward(**super_kwargs) + + +__all__ = [ + "PLBartForCausalLM", + "PLBartForConditionalGeneration", + "PLBartForSequenceClassification", + "PLBartModel", + "PLBartPreTrainedModel", +] From 1e2b4f02ff9b6b597cb3276641e909dc5c3a201f Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 16:40:06 +0200 Subject: [PATCH 15/68] copies/style + last round of bart models --> whisper next? --- .../models/informer/modeling_informer.py | 14 +- .../models/informer/modular_informer.py | 8 + .../models/mbart/modeling_mbart.py | 1 - .../patchtsmixer/modeling_patchtsmixer.py | 2 +- .../models/plbart/modeling_plbart.py | 15 + .../models/plbart/modular_plbart.py | 16 + .../qwen2_5_omni/modeling_qwen2_5_omni.py | 2 + .../qwen2_audio/modeling_qwen2_audio.py | 3 +- .../models/sew/feature_extractor_sew.py | 20 + src/transformers/models/sew/modeling_sew.py | 671 ++++++------------ src/transformers/models/sew/modular_sew.py | 572 +++++++++++++++ .../speech_to_text/modeling_speech_to_text.py | 317 ++++++--- .../models/unispeech/modeling_unispeech.py | 477 ++++--------- .../models/unispeech/modular_unispeech.py | 1 + .../unispeech_sat/modeling_unispeech_sat.py | 477 ++++--------- .../unispeech_sat/modular_unispeech_sat.py | 1 + .../modeling_wav2vec2_conformer.py | 6 +- .../models/wavlm/modeling_wavlm.py | 1 + .../models/wavlm/modular_wavlm.py | 2 + .../models/whisper/modeling_whisper.py | 8 +- 20 files changed, 1385 insertions(+), 1229 deletions(-) create mode 100644 src/transformers/models/sew/feature_extractor_sew.py create mode 100644 src/transformers/models/sew/modular_sew.py diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 233874521de7..5247c2b90a6d 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1692,13 +1692,6 @@ def forward( ) -def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: - """ - Computes the negative log likelihood loss from input distribution with respect to target. - """ - return -input.log_prob(target) - - def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor: """ Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero, @@ -1723,6 +1716,13 @@ def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] return input_tensor.mean(dim=dim) +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + @auto_docstring class InformerForPrediction(InformerPreTrainedModel): def __init__(self, config: InformerConfig): diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index edd12c192dc8..80b8f20966d5 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -31,6 +31,13 @@ from .configuration_informer import InformerConfig +def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: + """ + Computes the negative log likelihood loss from input distribution with respect to target. + """ + return -input.log_prob(target) + + class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder): pass @@ -79,6 +86,7 @@ def _init_weights(self, module): class InformerAttention(BartAttention): pass + class InformerProbSparseAttention(nn.Module): """Probabilistic Attention mechanism to select the "active" queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 013fd6883df7..758de3695197 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -278,7 +278,6 @@ def forward( return attn_output, attn_weights, past_key_value - class MBartEncoderLayer(nn.Module): def __init__(self, config: MBartConfig): super().__init__() diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index a089725b6198..93ee81b727e0 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -33,7 +33,7 @@ if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + pass logger = logging.get_logger(__name__) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index b4955c33865b..7fa87b6fb381 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_plbart.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import copy import math from typing import Callable, List, Optional, Tuple, Union diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index f7bc3df027ca..27152541d000 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch PLBART model.""" + import math from typing import List, Optional, Tuple, Union diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index 4849fe4e2094..d0bbe6c7a5d3 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -736,6 +736,8 @@ def forward( } +# (BC Dep) Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen25OmniAudio, WHISPER->Qwen25OmniAudio +# TODO(vasqu): fix copies when enabling whisper attn interface class Qwen2_5OmniAudioEncoderLayer(nn.Module): def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): super().__init__() diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index b71fe6181365..f9cbb6b9b8ad 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -362,7 +362,8 @@ def forward( } -# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO +# (BC Dep) Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO +# TODO(vasqu): fix copies when enabling whisper attn interface class Qwen2AudioEncoderLayer(nn.Module): def __init__(self, config: Qwen2AudioConfig): super().__init__() diff --git a/src/transformers/models/sew/feature_extractor_sew.py b/src/transformers/models/sew/feature_extractor_sew.py new file mode 100644 index 000000000000..b65a2e8f0cbd --- /dev/null +++ b/src/transformers/models/sew/feature_extractor_sew.py @@ -0,0 +1,20 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sew/modular_sew.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sew.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +import warnings + +from .modeling_sew import SEWFeatureEncoder + + +class SEWFeatureExtractor(SEWFeatureEncoder): + def __init__(self, config): + super().__init__(config) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. " + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 8bfeca060b49..76eab46d3a0f 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -1,170 +1,32 @@ -# coding=utf-8 -# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""PyTorch SEW model.""" - +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/sew/modular_sew.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_sew.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch -import torch.utils.checkpoint from torch import nn from torch.nn import CrossEntropyLoss from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import auto_docstring, logging from .configuration_sew import SEWConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward - - logger = logging.get_logger(__name__) -_HIDDEN_STATES_START_POSITION = 1 - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. - """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. - if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEW class SEWNoLayerNormConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -186,7 +48,6 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEW class SEWLayerNormConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -214,7 +75,6 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEW class SEWGroupNormConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -282,7 +142,6 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW class SEWSamePadLayer(nn.Module): def __init__(self, num_conv_pos_embeddings): super().__init__() @@ -316,7 +175,6 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEW class SEWFeatureEncoder(nn.Module): """Construct the features from raw audio waveform""" @@ -361,18 +219,33 @@ def forward(self, input_values): return hidden_states -class SEWFeatureExtractor(SEWFeatureEncoder): - def __init__(self, config): - super().__init__(config) - warnings.warn( - f"The class `{self.__class__.__name__}` has been depreciated " - "and will be removed in Transformers v5. " - f"Use `{self.__class__.__bases__[0].__name__}` instead.", - FutureWarning, - ) +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SEW class SEWAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -407,6 +280,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -418,6 +292,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -425,10 +302,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -443,18 +324,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -466,313 +347,43 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW -class SEWFlashAttention2(SEWAttention): - """ - SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # SEWFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("SEWFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) - - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class SEWSdpaAttention(SEWAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - -SEW_ATTENTION_CLASSES = { - "eager": SEWAttention, - "sdpa": SEWSdpaAttention, - "flash_attention_2": SEWFlashAttention2, -} - - -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->SEW class SEWFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -797,15 +408,15 @@ def forward(self, hidden_states): return hidden_states -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->SEW, WAV2VEC2->SEW class SEWEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = SEW_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = SEWAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -844,7 +455,6 @@ def __init__(self, config): self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.upsample = SEWUpsampling(config) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -859,7 +469,7 @@ def forward( if attention_mask is not None: expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - if self._use_flash_attention_2: + if self.config._attn_implementation == "flash_attention_2": # make sure padded tokens output 0 hidden_states[~expand_attention_mask] = 0.0 # 2d mask is passed through the layers @@ -1013,6 +623,126 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti return attention_mask +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + @auto_docstring class SEWModel(SEWPreTrainedModel): def __init__(self, config: SEWConfig): @@ -1136,12 +866,14 @@ def forward( ) +_HIDDEN_STATES_START_POSITION = 1 + + @auto_docstring( custom_intro=""" SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC). """ ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW class SEWForCTC(SEWPreTrainedModel): def __init__(self, config, target_lang: Optional[str] = None): r""" @@ -1294,11 +1026,10 @@ def forward( @auto_docstring( custom_intro=""" - SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB - Keyword Spotting. + SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like + SUPERB Keyword Spotting. """ ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW class SEWForSequenceClassification(SEWPreTrainedModel): def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py new file mode 100644 index 000000000000..0f1ef81e7246 --- /dev/null +++ b/src/transformers/models/sew/modular_sew.py @@ -0,0 +1,572 @@ +import math +import warnings +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn + +from ...activations import ACT2FN +from ...integrations.deepspeed import is_deepspeed_zero3_enabled +from ...integrations.fsdp import is_fsdp_managed_module +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import PreTrainedModel +from ...utils import auto_docstring +from ..bart.modeling_bart import BartAttention +from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2EncoderLayer, + Wav2Vec2FeatureEncoder, + Wav2Vec2FeedForward, + Wav2Vec2ForCTC, + Wav2Vec2ForSequenceClassification, + Wav2Vec2GroupNormConvLayer, + Wav2Vec2LayerNormConvLayer, + Wav2Vec2NoLayerNormConvLayer, + Wav2Vec2SamePadLayer, +) +from .configuration_sew import SEWConfig + + +_HIDDEN_STATES_START_POSITION = 1 + + +# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices +def _compute_mask_indices( + shape: Tuple[int, int], + mask_prob: float, + mask_length: int, + attention_mask: Optional[torch.LongTensor] = None, + min_masks: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for + ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on + CPU as part of the preprocessing during training. + + Args: + shape: The shape for which to compute masks. This should be of a tuple of size 2 where + the first element is the batch size and the second element is the length of the axis to span. + mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of + independently generated mask spans of length `mask_length` is computed by + `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the + actual percentage will be smaller. + mask_length: size of the mask + min_masks: minimum number of masked spans + attention_mask: A (right-padded) attention mask which independently shortens the feature axis of + each batch dimension. + """ + batch_size, sequence_length = shape + + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" + f" and `sequence_length`: {sequence_length}`" + ) + + # epsilon is used for probabilistic rounding + epsilon = np.random.rand(1).item() + + def compute_num_masked_span(input_length): + """Given input length, compute how many spans should be masked""" + num_masked_span = int(mask_prob * input_length / mask_length + epsilon) + num_masked_span = max(num_masked_span, min_masks) + + # make sure num masked span <= sequence_length + if num_masked_span * mask_length > sequence_length: + num_masked_span = sequence_length // mask_length + + # make sure num_masked span is also <= input_length - (mask_length - 1) + if input_length - (mask_length - 1) < num_masked_span: + num_masked_span = max(input_length - (mask_length - 1), 0) + + return num_masked_span + + # compute number of masked spans in batch + input_lengths = ( + attention_mask.detach().sum(-1).tolist() + if attention_mask is not None + else [sequence_length for _ in range(batch_size)] + ) + + # SpecAugment mask to fill + spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) + spec_aug_mask_idxs = [] + + max_num_masked_span = compute_num_masked_span(sequence_length) + + if max_num_masked_span == 0: + return spec_aug_mask + + for input_length in input_lengths: + # compute num of masked spans for this input + num_masked_span = compute_num_masked_span(input_length) + + # get random indices to mask + spec_aug_mask_idx = np.random.choice( + np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False + ) + + # pick first sampled index that will serve as a dummy index to pad vector + # to ensure same dimension for all batches due to probabilistic rounding + # Picking first sample just pads those vectors twice. + if len(spec_aug_mask_idx) == 0: + # this case can only happen if `input_length` is strictly smaller then + # `sequence_length` in which case the last token has to be a padding + # token which we can use as a dummy mask id + dummy_mask_idx = sequence_length - 1 + else: + dummy_mask_idx = spec_aug_mask_idx[0] + + spec_aug_mask_idx = np.concatenate( + [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] + ) + spec_aug_mask_idxs.append(spec_aug_mask_idx) + + spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) + + # expand masked indices to masked spans + spec_aug_mask_idxs = np.broadcast_to( + spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) + + # add offset to the starting indexes so that indexes now create a span + offsets = np.arange(mask_length)[None, None, :] + offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( + batch_size, max_num_masked_span * mask_length + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets + + # ensure that we cannot have indices larger than sequence_length + if spec_aug_mask_idxs.max() > sequence_length - 1: + spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 + + # scatter indices to mask + np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) + + return spec_aug_mask + + +class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer): + pass + + +class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer): + pass + + +class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer): + pass + + +class SEWPositionalConvEmbedding(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.hidden_size, + config.hidden_size, + kernel_size=config.num_conv_pos_embeddings, + padding=config.num_conv_pos_embeddings // 2, + groups=config.num_conv_pos_embedding_groups, + stride=config.squeeze_factor, + ) + + weight_norm = nn.utils.weight_norm + if hasattr(nn.utils.parametrizations, "weight_norm"): + weight_norm = nn.utils.parametrizations.weight_norm + + if is_deepspeed_zero3_enabled(): + import deepspeed + + with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0): + self.conv = weight_norm(self.conv, name="weight", dim=2) + if hasattr(self.conv, "parametrizations"): + weight_g = self.conv.parametrizations.weight.original0 + weight_v = self.conv.parametrizations.weight.original1 + else: + weight_g = self.conv.weight_g + weight_v = self.conv.weight_v + deepspeed.zero.register_external_parameter(self, weight_v) + deepspeed.zero.register_external_parameter(self, weight_g) + else: + self.conv = weight_norm(self.conv, name="weight", dim=2) + + self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings) + self.activation = ACT2FN[config.feat_extract_activation] + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = self.padding(hidden_states) + hidden_states = self.activation(hidden_states) + + return hidden_states + + +class SEWSamePadLayer(Wav2Vec2SamePadLayer): + pass + + +class SEWUpsampling(nn.Module): + def __init__(self, config): + super().__init__() + self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor) + self.activation = ACT2FN[config.feat_extract_activation] + self.squeeze_factor = config.squeeze_factor + + def forward(self, hidden_states): + hidden_states = self.projection(hidden_states) + hidden_states = self.activation(hidden_states) + + if self.squeeze_factor > 1: + # transform embedding channels to sequence length + bsz, src_len, src_embed_dim = hidden_states.size() + tgt_len = src_len * self.squeeze_factor + tgt_embed_dim = src_embed_dim // self.squeeze_factor + hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim) + hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim) + + return hidden_states + + +class SEWFeatureEncoder(Wav2Vec2FeatureEncoder): + pass + + +class SEWFeatureExtractor(SEWFeatureEncoder): + def __init__(self, config): + super().__init__(config) + warnings.warn( + f"The class `{self.__class__.__name__}` has been depreciated " + "and will be removed in Transformers v5. " + f"Use `{self.__class__.__bases__[0].__name__}` instead.", + FutureWarning, + ) + + +class SEWAttention(BartAttention): + pass + + +class SEWFeedForward(Wav2Vec2FeedForward): + pass + + +class SEWEncoderLayer(Wav2Vec2EncoderLayer): + pass + + +class SEWEncoder(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.pos_conv_embed = SEWPositionalConvEmbedding(config) + self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor) + self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.dropout = nn.Dropout(config.hidden_dropout) + self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.upsample = SEWUpsampling(config) + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + attention_mask=None, + output_attentions=False, + output_hidden_states=False, + return_dict=True, + ): + all_hidden_states = () if output_hidden_states else None + all_self_attentions = () if output_attentions else None + + if attention_mask is not None: + expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) + if self.config._attn_implementation == "flash_attention_2": + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + # make sure padded tokens output 0 + hidden_states[~expand_attention_mask] = 0.0 + input_lengths = (attention_mask.long()).sum(-1) + # apply pooling formula to get real output_lengths + output_lengths = input_lengths // self.config.squeeze_factor + max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor + attention_ids = ( + torch.arange(0, max_encoder_length, device=output_lengths.device) + .view(1, -1) + .expand(output_lengths.shape[0], -1) + ) + attention_mask = (attention_ids < output_lengths.view(-1, 1)).long() + + # extend attention_mask + attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) + attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min + attention_mask = attention_mask.expand( + attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] + ) + + n_input_timesteps = hidden_states.shape[1] + + hidden_states = hidden_states.transpose(1, 2) + position_embeddings = self.pos_conv_embed(hidden_states) + pooled_hidden_states = self.pool(hidden_states) + min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1)) + hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length] + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self) + + for layer in self.layers: + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) + dropout_probability = torch.rand([]) + + skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False + if not skip_the_layer or synced_gpus: + # under fsdp or deepspeed zero3 all gpus must run in sync + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = layer( + hidden_states, attention_mask=attention_mask, output_attentions=output_attentions + ) + hidden_states = layer_outputs[0] + + if skip_the_layer: + layer_outputs = (None, None) + + if output_attentions: + all_self_attentions = all_self_attentions + (layer_outputs[1],) + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + hidden_states = self.upsample(hidden_states) + if hidden_states.shape[1] < n_input_timesteps: + hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1])) + + if not return_dict: + return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +@auto_docstring +class SEWPreTrainedModel(PreTrainedModel): + config_class = SEWConfig + base_model_prefix = "sew" + main_input_name = "input_values" + supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, SEWPositionalConvEmbedding): + nn.init.normal_( + module.conv.weight, + mean=0, + std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)), + ) + nn.init.constant_(module.conv.bias, 0) + elif isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + elif isinstance(module, nn.Conv1d): + if is_deepspeed_zero3_enabled(): + import deepspeed + + if hasattr(module, "weight_v") and hasattr(module, "weight_g"): + with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0): + nn.init.kaiming_normal_(module.weight.data) + else: + nn.init.kaiming_normal_(module.weight.data) + + if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None: + module.bias.data.zero_() + + def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + """ + Computes the output length of the convolutional layers + """ + + def _conv_out_length(input_length, kernel_size, stride): + # 1D convolutional layer output length formula taken + # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html + return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1 + + for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): + input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + + return input_lengths + + def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long) + batch_size = attention_mask.shape[0] + + attention_mask = torch.zeros( + (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device + ) + # these two operations makes sure that all values before the output lengths idxs are attended to + attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1 + attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool() + return attention_mask + + +@auto_docstring +class SEWModel(SEWPreTrainedModel): + def __init__(self, config: SEWConfig): + super().__init__(config) + self.config = config + self.feature_extractor = SEWFeatureEncoder(config) + self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps) + + self.project_features = config.conv_dim[-1] != config.hidden_size + if self.project_features: + self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size) + self.feature_dropout = nn.Dropout(config.feat_proj_dropout) + + if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0: + self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_()) + + self.encoder = SEWEncoder(config) + + # Initialize weights and apply final processing + self.post_init() + + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states + def _mask_hidden_states( + self, + hidden_states: torch.FloatTensor, + mask_time_indices: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + ): + """ + Masks extracted features along time axis and/or along feature axis according to + [SpecAugment](https://arxiv.org/abs/1904.08779). + """ + + # `config.apply_spec_augment` can set masking to False + if not getattr(self.config, "apply_spec_augment", True): + return hidden_states + + # generate indices & apply SpecAugment along time axis + batch_size, sequence_length, hidden_size = hidden_states.size() + + if mask_time_indices is not None: + # apply SpecAugment along time axis with given mask_time_indices + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + elif self.config.mask_time_prob > 0 and self.training: + mask_time_indices = _compute_mask_indices( + (batch_size, sequence_length), + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + attention_mask=attention_mask, + min_masks=self.config.mask_time_min_masks, + ) + mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) + + if self.config.mask_feature_prob > 0 and self.training: + # generate indices & apply SpecAugment along feature axis + mask_feature_indices = _compute_mask_indices( + (batch_size, hidden_size), + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + min_masks=self.config.mask_feature_min_masks, + ) + mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool) + mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1) + hidden_states[mask_feature_indices] = 0 + + return hidden_states + + @auto_docstring + def forward( + self, + input_values: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor] = None, + mask_time_indices: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices to mask extracted features for contrastive loss. When in training mode, model learns to predict + masked extracted features in *config.proj_codevector_dim* space. + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + extract_features = self.feature_extractor(input_values) + extract_features = extract_features.transpose(1, 2) + extract_features = self.layer_norm(extract_features) + + if self.project_features: + extract_features = self.feature_projection(extract_features) + hidden_states = self.feature_dropout(extract_features) + + if attention_mask is not None: + # compute reduced attention_mask corresponding to feature vectors + attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask) + + hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices) + + encoder_outputs = self.encoder( + hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = encoder_outputs[0] + + if not return_dict: + return (hidden_states,) + encoder_outputs[1:] + + return BaseModelOutput( + last_hidden_state=hidden_states, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SEWForCTC(Wav2Vec2ForCTC): + pass + + +class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification): + pass + + +__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"] diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index b5842217fce8..1cb084c95014 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -15,7 +15,7 @@ """PyTorch Speech2Text model.""" import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -23,21 +23,33 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, Seq2SeqLMOutput, Seq2SeqModelOutput, ) -from ...modeling_utils import PreTrainedModel +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack from ...utils import ( auto_docstring, + is_torch_flex_attn_available, logging, ) from .configuration_speech_to_text import Speech2TextConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -161,6 +173,34 @@ def create_position_ids_from_input_ids( return incremental_indices.long() + padding_idx +# Copied from transformers.models.bart.modeling_bart.eager_attn_forward +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Speech2Text class Speech2TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -196,6 +236,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -207,6 +248,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -214,10 +258,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -232,18 +280,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -255,72 +303,41 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) + else: + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value - - -SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention} + return attn_output, attn_weights, past_key_value # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT @@ -329,7 +346,7 @@ def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = Speech2TextAttention( embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, @@ -398,7 +415,7 @@ def __init__(self, config: Speech2TextConfig): super().__init__() self.embed_dim = config.d_model - self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.self_attn = Speech2TextAttention( embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, @@ -411,7 +428,7 @@ def __init__(self, config: Speech2TextConfig): self.activation_dropout = config.activation_dropout self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) - self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation]( + self.encoder_attn = Speech2TextAttention( self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, @@ -519,6 +536,9 @@ class Speech2TextPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): std = self.config.init_std @@ -655,10 +675,12 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # expand attention_mask - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + _unsupported_features = output_attentions is True or head_mask is not None + attention_mask = self._update_full_mask( + attention_mask, + inputs_embeds, + _unsupported_features, + ) encoder_states = () if output_hidden_states else None all_attentions = () if output_attentions else None @@ -713,6 +735,34 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) + # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class Speech2TextDecoder(Speech2TextPreTrainedModel): """ @@ -857,16 +907,21 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + _unsupported_features, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, ) - - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) # embed positions positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length) @@ -963,6 +1018,92 @@ def forward( cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class Speech2TextModel(Speech2TextPreTrainedModel): diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 48f351e21dba..deac66f60216 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -7,7 +7,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -17,7 +17,8 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -25,13 +26,14 @@ SequenceClassifierOutput, Wav2Vec2BaseModelOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_unispeech import UniSpeechConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -263,6 +265,33 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class UniSpeechAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -297,6 +326,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -308,6 +338,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -315,10 +348,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -333,18 +370,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -356,303 +393,43 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class UniSpeechFlashAttention2(UniSpeechAttention): - """ - UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # UniSpeechFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("UniSpeechFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class UniSpeechSdpaAttention(UniSpeechAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "UniSpeechModel is using UniSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class UniSpeechFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -677,21 +454,15 @@ def forward(self, hidden_states): return hidden_states -UNISPEECH_ATTENTION_CLASSES = { - "eager": UniSpeechAttention, - "sdpa": UniSpeechSdpaAttention, - "flash_attention_2": UniSpeechFlashAttention2, -} - - class UniSpeechEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -728,7 +499,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -745,16 +515,13 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -803,6 +570,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -832,11 +626,12 @@ def forward(self, hidden_states: torch.FloatTensor): class UniSpeechEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -885,7 +680,6 @@ def __init__(self, config): [UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -899,19 +693,16 @@ def forward( all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -962,6 +753,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechGumbelVectorQuantizer(nn.Module): """ @@ -1041,6 +859,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 9c1eae48e851..645a7e375bf7 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -135,6 +135,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index a24d50398775..66eabfe3eeef 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -7,7 +7,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -17,7 +17,8 @@ from ...activations import ACT2FN from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa +from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, CausalLMOutput, @@ -27,13 +28,14 @@ Wav2Vec2BaseModelOutput, XVectorOutput, ) -from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, is_peft_available, logging +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging from .configuration_unispeech_sat import UniSpeechSatConfig -if is_flash_attn_available(): - from ...modeling_flash_attention_utils import _flash_attention_forward +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -266,6 +268,33 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states +def eager_attn_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, +): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + + attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value) + attn_output = attn_output.transpose(1, 2).contiguous() + + return attn_output, attn_weights + + class UniSpeechSatAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -300,6 +329,7 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + # Kept for BC dependencies def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -311,6 +341,9 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + # TODO: we need a refactor so that the different attention modules can get their specific kwargs + # ATM, we have mixed things encoder, decoder, and encoder-decoder attn + **kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -318,10 +351,14 @@ def forward( # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() + # determine input shapes + bsz, tgt_len = hidden_states.shape[:-1] + q_input_shape = (bsz, tgt_len, -1, self.head_dim) + kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape # get query proj - query_states = self.q_proj(hidden_states) * self.scaling + query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) + # get key, value proj # `past_key_value[0].shape[2] == key_value_states.shape[1]` # is checking that the `sequence_length` of the `past_key_value` is the same as @@ -336,18 +373,18 @@ def forward( value_states = past_key_value[1] elif is_cross_attention: # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) + key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) elif past_key_value is not None: # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) else: # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) + key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) if self.is_decoder: # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. @@ -359,303 +396,43 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) - key_states = key_states.reshape(*proj_shape) - value_states = value_states.reshape(*proj_shape) - - src_len = key_states.size(1) - attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) - - if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): - raise ValueError( - f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" + attention_interface: Callable = eager_attn_forward + attention_type = self.config._attn_implementation + if self.config._attn_implementation != "eager": + if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." ) - attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - attn_weights = nn.functional.softmax(attn_weights, dim=-1) - - if layer_head_mask is not None: - if layer_head_mask.size() != (self.num_heads,): - raise ValueError( - f"Head mask for a single layer should be of size {(self.num_heads,)}, but is" - f" {layer_head_mask.size()}" + elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." ) - attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) - - if output_attentions: - # this operation is a bit awkward, but it's required to - # make sure that attn_weights keeps its gradient. - # In order to do so, attn_weights have to be reshaped - # twice and have to be reused in the following - attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) - attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) - else: - attn_weights_reshaped = None - - attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) - - attn_output = torch.bmm(attn_probs, value_states) - - if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, attn_weights_reshaped, past_key_value - - -class UniSpeechSatFlashAttention2(UniSpeechSatAttention): - """ - UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays - untouched. The only required change would be on the forward pass where it needs to correctly call the public API of - flash attention and deal with padding tokens in case the input contains any of them. - """ - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. - # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. - # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). - self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # UniSpeechSatFlashAttention2 attention does not support output_attentions - if output_attentions: - raise ValueError("UniSpeechSatFlashAttention2 attention does not support output_attentions") - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, q_len, _ = hidden_states.size() - - # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] - - # In PEFT, usually we cast the layer norms in float32 for training stability reasons - # therefore the input hidden states gets silently casted in float32. Hence, we need - # cast them back in the correct dtype just to be sure everything works as expected. - # This might slowdown training & inference so it is recommended to not cast the LayerNorms - # in fp32. (LlamaRMSNorm handles it correctly) - - input_dtype = query_states.dtype - if input_dtype == torch.float32: - if torch.is_autocast_enabled(): - target_dtype = torch.get_autocast_gpu_dtype() - # Handle the case where the model is quantized - elif hasattr(self.config, "_pre_quantization_dtype"): - target_dtype = self.config._pre_quantization_dtype else: - target_dtype = self.q_proj.weight.dtype - - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) - - query_states = query_states.to(target_dtype) - key_states = key_states.to(target_dtype) - value_states = value_states.to(target_dtype) + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - attn_output = _flash_attention_forward( + attn_output, attn_weights = attention_interface( + self, query_states, key_states, value_states, attention_mask, - q_len, - dropout=self.dropout if self.training else 0.0, - is_causal=self.is_causal, - use_top_left_mask=self._flash_attn_uses_top_left_mask, + dropout=0.0 if not self.training else self.dropout, + scaling=self.scaling, + layer_head_mask=layer_head_mask, + **kwargs, ) - attn_output = attn_output.reshape(bsz, q_len, -1) + attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous() attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - return attn_output, attn_weights, past_key_value -class UniSpeechSatSdpaAttention(UniSpeechSatAttention): - def forward( - self, - hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - """Input shape: Batch x Time x Channel""" - if output_attentions or layer_head_mask is not None: - # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented. - logger.warning_once( - "UniSpeechSatModel is using UniSpeechSatSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` or `layer_head_mask` not None. Falling back to the manual attention" - ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states, - key_value_states=key_value_states, - past_key_value=past_key_value, - attention_mask=attention_mask, - layer_head_mask=layer_head_mask, - output_attentions=output_attentions, - ) - - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - - bsz, tgt_len, _ = hidden_states.size() - - # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) - - # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment - # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. - # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False - - # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, - # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 - attn_output = torch.nn.functional.scaled_dot_product_attention( - query_states, - key_states, - value_states, - attn_mask=attention_mask, - dropout_p=self.dropout if self.training else 0.0, - is_causal=is_causal, - ) - - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) - - # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be - # partitioned across GPUs when using tensor-parallelism. - attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - - attn_output = self.out_proj(attn_output) - - return attn_output, None, past_key_value - - class UniSpeechSatFeedForward(nn.Module): def __init__(self, config): super().__init__() @@ -680,21 +457,15 @@ def forward(self, hidden_states): return hidden_states -UNISPEECH_SAT_ATTENTION_CLASSES = { - "eager": UniSpeechSatAttention, - "sdpa": UniSpeechSatSdpaAttention, - "flash_attention_2": UniSpeechSatFlashAttention2, -} - - class UniSpeechSatEncoderLayer(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechSatAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) @@ -731,7 +502,6 @@ def __init__(self, config): self.dropout = nn.Dropout(config.hidden_dropout) self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)]) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -748,16 +518,13 @@ def forward( # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -806,6 +573,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechSatAttnAdapterLayer(nn.Module): def __init__(self, config): @@ -835,11 +629,12 @@ def forward(self, hidden_states: torch.FloatTensor): class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module): def __init__(self, config): super().__init__() - self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation]( + self.attention = UniSpeechSatAttention( embed_dim=config.hidden_size, num_heads=config.num_attention_heads, dropout=config.attention_dropout, is_decoder=False, + config=config, ) self.dropout = nn.Dropout(config.hidden_dropout) self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -888,7 +683,6 @@ def __init__(self, config): [UniSpeechSatEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)] ) self.gradient_checkpointing = False - self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" def forward( self, @@ -902,19 +696,16 @@ def forward( all_self_attentions = () if output_attentions else None if attention_mask is not None: - # make sure padded tokens are not attended to + # make sure padded tokens output 0 expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) - hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # extend attention_mask - attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype) - attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min - attention_mask = attention_mask.expand( - attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1] - ) + hidden_states[~expand_attention_mask] = 0 + + _unsupported_features = output_attentions is True + attention_mask = self._update_full_mask( + attention_mask, + hidden_states, + _unsupported_features, + ) position_embeddings = self.pos_conv_embed(hidden_states) hidden_states = hidden_states + position_embeddings @@ -965,6 +756,33 @@ def forward( attentions=all_self_attentions, ) + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + class UniSpeechSatGumbelVectorQuantizer(nn.Module): """ @@ -1044,6 +862,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index 22de13c3bdc0..d38678286a29 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -145,6 +145,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py index fc660d06ac9a..af02fa91ee1d 100644 --- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py @@ -26,7 +26,11 @@ XVectorOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import ModelOutput, auto_docstring, is_peft_available +from ...utils import ( + ModelOutput, + auto_docstring, + is_peft_available, +) from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index 90330c264525..11670ea7d21b 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -617,6 +617,7 @@ class WavLMPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = False _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index b74d9fc703e0..7a3b887bb67e 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -527,6 +527,8 @@ class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = False _supports_sdpa = False + _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 59a999c2abf3..1f7781084ed8 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -27,7 +27,10 @@ from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import AttentionMaskConverter -from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available +from ...modeling_flash_attention_utils import ( + flash_attn_supports_top_left_mask, + is_flash_attn_available, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -570,7 +573,8 @@ def forward( } -# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER +# (BC Dep) Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER +# TODO(vasqu): fix copies when enabling whisper attn interface class WhisperEncoderLayer(nn.Module): def __init__(self, config: WhisperConfig): super().__init__() From dccabeb4117520aba2e1e0404eac5013004460e0 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 17:06:19 +0200 Subject: [PATCH 16/68] remove unnecessary _reshape function and remove copy to whisper --- src/transformers/models/bart/modeling_bart.py | 4 ---- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 4 ---- src/transformers/models/biogpt/modeling_biogpt.py | 4 ---- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 ---- .../models/blenderbot_small/modeling_blenderbot_small.py | 4 ---- src/transformers/models/data2vec/modeling_data2vec_audio.py | 4 ---- src/transformers/models/hubert/modeling_hubert.py | 4 ---- src/transformers/models/informer/modeling_informer.py | 4 ---- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 ---- src/transformers/models/marian/modeling_marian.py | 4 ---- src/transformers/models/mbart/modeling_mbart.py | 4 ---- src/transformers/models/musicgen/modeling_musicgen.py | 4 ---- .../models/musicgen_melody/modeling_musicgen_melody.py | 4 ---- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 4 ---- src/transformers/models/patchtsmixer/modeling_patchtsmixer.py | 4 ---- src/transformers/models/patchtst/modeling_patchtst.py | 4 ---- src/transformers/models/pegasus/modeling_pegasus.py | 4 ---- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 ---- src/transformers/models/plbart/modeling_plbart.py | 4 ---- src/transformers/models/qwen2_audio/modeling_qwen2_audio.py | 1 - src/transformers/models/sew/modeling_sew.py | 4 ---- .../models/speech_to_text/modeling_speech_to_text.py | 4 ---- .../modeling_time_series_transformer.py | 4 ---- src/transformers/models/unispeech/modeling_unispeech.py | 4 ---- .../models/unispeech_sat/modeling_unispeech_sat.py | 4 ---- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 4 ---- src/transformers/models/whisper/modeling_whisper.py | 1 - 27 files changed, 102 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index b3612f69f473..cdfd286a07fd 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -167,10 +167,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 031a73145f99..bc1ab64a6e5f 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1230,10 +1230,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index eba6dfc498d0..5ca8cebbe549 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -148,10 +148,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 8a4e09841683..e15e8f9c7f6b 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -166,10 +166,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 0b5488d5c00e..99b291e79c52 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -150,10 +150,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 5b908f0c2612..fcf4513cbd01 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -230,10 +230,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 5fee0c483bff..92672ec011b7 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -287,10 +287,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 5247c2b90a6d..3dc70d6dd12c 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -311,10 +311,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f47e45183cec..ebce0ca7d07f 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -239,10 +239,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 2779259d543b..a7c7f47f8afe 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -164,10 +164,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 758de3695197..68e9eb5a2c01 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -174,10 +174,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 7f66abd77a32..0414093b3346 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -213,10 +213,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 9eb3672e7f2a..ea7b00bca349 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -229,10 +229,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b227057d1d30..46e30d034234 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -534,10 +534,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 93ee81b727e0..1c270d088c0b 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -304,10 +304,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 45cf9893c2d2..7670055f46ae 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -97,10 +97,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 63508f5b3b0e..ef01f9e5e8d7 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -165,10 +165,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 3bdafa78488a..b13fa247ec89 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -190,10 +190,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 7fa87b6fb381..eb7042db94d8 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -175,10 +175,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index f9cbb6b9b8ad..d1938b4664e2 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -127,7 +127,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 76eab46d3a0f..85f75c1586db 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -280,10 +280,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 1cb084c95014..dae7e3136f46 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -236,10 +236,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 8f5b87e648f3..e3c432561861 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -339,10 +339,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index deac66f60216..f57601095ed6 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -326,10 +326,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 66eabfe3eeef..18058a9dec4d 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -329,10 +329,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 870a6f4eeaa5..54c060dc2b62 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -534,10 +534,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Kept for BC dependencies - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 1f7781084ed8..b595f2bc8cb8 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -262,7 +262,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() From cecd0a415a88504e3f96c7a3e7c3df0907f7c1aa Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 17:15:59 +0200 Subject: [PATCH 17/68] add skip for decoder-only models out of enc-dec (same as in bart) --- tests/models/blenderbot/test_modeling_blenderbot.py | 4 ++++ .../models/blenderbot_small/test_modeling_blenderbot_small.py | 4 ++++ tests/models/marian/test_modeling_marian.py | 4 ++++ tests/models/pegasus/test_modeling_pegasus.py | 4 ++++ tests/models/pegasus_x/test_modeling_pegasus_x.py | 4 ++++ tests/models/plbart/test_modeling_plbart.py | 4 ++++ 6 files changed, 24 insertions(+) diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index 945506eeb3ca..0c54544ca66f 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -554,3 +554,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index 8bdf116647a6..e7916c6c1b69 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -563,3 +563,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 52f2232ee69a..ec013467432a 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -850,3 +850,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index d0b86d5fa5b5..9f0ed926ef28 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -597,3 +597,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index a1439ebaabc5..c1862d9301ee 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -872,3 +872,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 7255f2b7c530..303e5b734531 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -677,3 +677,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return From ac61dd79f7c80e0ec0743e4919e450af4dafa3c4 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 17:42:59 +0200 Subject: [PATCH 18/68] bring back licences --- .../models/data2vec/modeling_data2vec_audio.py | 15 +++++++++++++++ .../models/data2vec/modular_data2vec_audio.py | 17 +++++++++++++++++ .../models/hubert/modeling_hubert.py | 15 +++++++++++++++ .../models/hubert/modular_hubert.py | 16 ++++++++++++++++ .../models/informer/modeling_informer.py | 15 +++++++++++++++ .../models/informer/modular_informer.py | 16 ++++++++++++++++ .../models/sew/feature_extractor_sew.py | 14 ++++++++++++++ src/transformers/models/sew/modeling_sew.py | 15 +++++++++++++++ src/transformers/models/sew/modular_sew.py | 16 ++++++++++++++++ .../models/unispeech/modeling_unispeech.py | 15 +++++++++++++++ .../models/unispeech/modular_unispeech.py | 16 ++++++++++++++++ .../unispeech_sat/modeling_unispeech_sat.py | 15 +++++++++++++++ .../unispeech_sat/modular_unispeech_sat.py | 16 ++++++++++++++++ src/transformers/models/wavlm/modular_wavlm.py | 1 - 14 files changed, 201 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index fcf4513cbd01..9969ac180a10 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_data2vec_audio.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings from typing import Callable, Optional, Tuple, Union diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 58934d2e86a7..0b4695c1e28c 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Data2VecText model.""" + import math import torch @@ -124,6 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 92672ec011b7..c2540b3970e1 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_hubert.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import warnings from typing import Callable, Optional, Tuple, Union diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index 5286f9065b58..c0454452f029 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Hubert model.""" + from typing import Optional, Tuple, Union import torch diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 3dc70d6dd12c..78f297564c2b 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_informer.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import Callable, List, Optional, Tuple, Union import numpy as np diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 80b8f20966d5..d54b8b94c7b8 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch Informer model.""" + from typing import Optional, Tuple, Union import numpy as np diff --git a/src/transformers/models/sew/feature_extractor_sew.py b/src/transformers/models/sew/feature_extractor_sew.py index b65a2e8f0cbd..c58812b58d6b 100644 --- a/src/transformers/models/sew/feature_extractor_sew.py +++ b/src/transformers/models/sew/feature_extractor_sew.py @@ -4,6 +4,20 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_sew.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import warnings from .modeling_sew import SEWFeatureEncoder diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 85f75c1586db..752da48cc35e 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_sew.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings from typing import Callable, Optional, Tuple, Union diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 0f1ef81e7246..69caa26d3a4a 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch SEW model.""" + import math import warnings from typing import Optional, Tuple, Union diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index f57601095ed6..ec2b2d73f7a1 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_unispeech.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings from dataclasses import dataclass diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 645a7e375bf7..5a9133089aeb 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UniSpeech model.""" + import math import warnings from dataclasses import dataclass diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 18058a9dec4d..61a69dc88af8 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -4,6 +4,21 @@ # the file from the modular. If any change should be done, please apply the change to the # modular_unispeech_sat.py file directly. One of our CI enforces this. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import math import warnings from dataclasses import dataclass diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index d38678286a29..f86c397a047c 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -1,3 +1,19 @@ +# coding=utf-8 +# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch UniSpeechSat model.""" + import math import warnings from dataclasses import dataclass diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py index 7a3b887bb67e..1ff9d5052c0a 100644 --- a/src/transformers/models/wavlm/modular_wavlm.py +++ b/src/transformers/models/wavlm/modular_wavlm.py @@ -527,7 +527,6 @@ class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = False _supports_sdpa = False - _supports_sdpa = False _supports_flex_attn = False def _init_weights(self, module): From a6e848d1bad4693070cc69abca5ef6eb7ee9455e Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 14 May 2025 17:49:24 +0200 Subject: [PATCH 19/68] remove comment, added to pr read instead --- src/transformers/models/bart/configuration_bart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bart/configuration_bart.py b/src/transformers/models/bart/configuration_bart.py index 1dc4c6101d0b..4ce4316e3c03 100644 --- a/src/transformers/models/bart/configuration_bart.py +++ b/src/transformers/models/bart/configuration_bart.py @@ -107,7 +107,6 @@ class BartConfig(PretrainedConfig): model_type = "bart" keys_to_ignore_at_inference = ["past_key_values"] attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"} - # TODO: add tp plan def __init__( self, From ddfc515ed7d04bbe8e6c5ac2a5afe361551e62be Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 10:21:49 +0200 Subject: [PATCH 20/68] mostly docs --- docs/source/en/model_doc/biogpt.md | 11 ++++++----- docs/source/en/model_doc/blenderbot-small.md | 4 +++- docs/source/en/model_doc/blenderbot.md | 6 ++++-- docs/source/en/model_doc/marian.md | 4 +++- docs/source/en/model_doc/nllb-moe.md | 10 ++++++---- docs/source/en/model_doc/pegasus.md | 2 ++ docs/source/en/model_doc/pegasus_x.md | 2 ++ docs/source/en/model_doc/plbart.md | 6 ++++-- docs/source/en/model_doc/speech_to_text.md | 2 ++ docs/source/en/model_doc/time_series_transformer.md | 2 ++ .../musicgen_melody/modeling_musicgen_melody.py | 1 + src/transformers/models/sew/modeling_sew.py | 1 + src/transformers/models/sew/modular_sew.py | 1 + 13 files changed, 37 insertions(+), 15 deletions(-) diff --git a/docs/source/en/model_doc/biogpt.md b/docs/source/en/model_doc/biogpt.md index ab8aea6c29e8..fc4979b75ba2 100644 --- a/docs/source/en/model_doc/biogpt.md +++ b/docs/source/en/model_doc/biogpt.md @@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention SDPA
@@ -39,13 +40,13 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The ### Using Scaled Dot Product Attention (SDPA) -PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function -encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the -[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) +PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function +encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the +[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html) or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention) page for more information. -SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set +SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set `attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used. ``` @@ -108,7 +109,7 @@ we saw the following speedups during inference. [[autodoc]] BioGptForCausalLM - forward - + ## BioGptForTokenClassification [[autodoc]] BioGptForTokenClassification diff --git a/docs/source/en/model_doc/blenderbot-small.md b/docs/source/en/model_doc/blenderbot-small.md index 647a865de339..341e43c03040 100644 --- a/docs/source/en/model_doc/blenderbot-small.md +++ b/docs/source/en/model_doc/blenderbot-small.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA Note that [`BlenderbotSmallModel`] and @@ -52,7 +54,7 @@ found [here](https://github.com/facebookresearch/ParlAI). ## Usage tips -Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than +Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. diff --git a/docs/source/en/model_doc/blenderbot.md b/docs/source/en/model_doc/blenderbot.md index ec24d5ed7495..adfa6841e10a 100644 --- a/docs/source/en/model_doc/blenderbot.md +++ b/docs/source/en/model_doc/blenderbot.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA ## Overview @@ -45,7 +47,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The ## Usage tips and example -Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right +Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than the left. An example: @@ -71,7 +73,7 @@ An example: `facebook/blenderbot_small_90M`, have a different architecture and consequently should be used with [BlenderbotSmall](blenderbot-small). - + ## Resources - [Causal language modeling task guide](../tasks/language_modeling) diff --git a/docs/source/en/model_doc/marian.md b/docs/source/en/model_doc/marian.md index 80bb73d26df1..4fcd6363559c 100644 --- a/docs/source/en/model_doc/marian.md +++ b/docs/source/en/model_doc/marian.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA ## Overview @@ -155,7 +157,7 @@ Example of translating english to many romance languages, using old-style 2 char >>> model = MarianMTModel.from_pretrained(model_name) >>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True)) >>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated] -["c'est une phrase en anglais que nous voulons traduire en français", +["c'est une phrase en anglais que nous voulons traduire en français", 'Isto deve ir para o português.', 'Y esto al español'] ``` diff --git a/docs/source/en/model_doc/nllb-moe.md b/docs/source/en/model_doc/nllb-moe.md index 65a4812ed6ab..3f0be7a7c96e 100644 --- a/docs/source/en/model_doc/nllb-moe.md +++ b/docs/source/en/model_doc/nllb-moe.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview @@ -51,10 +53,10 @@ The original code can be found [here](https://github.com/facebookresearch/fairse ## Implementation differences with SwitchTransformers -The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the -highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed, -which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden -states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism. +The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the +highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed, +which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden +states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism. ## Generating with NLLB-MoE diff --git a/docs/source/en/model_doc/pegasus.md b/docs/source/en/model_doc/pegasus.md index bdb61e66d984..5681ac9b58a0 100644 --- a/docs/source/en/model_doc/pegasus.md +++ b/docs/source/en/model_doc/pegasus.md @@ -21,6 +21,8 @@ rendered properly in your Markdown viewer. TensorFlow Flax +FlashAttention +SDPA ## Overview diff --git a/docs/source/en/model_doc/pegasus_x.md b/docs/source/en/model_doc/pegasus_x.md index 3f982263cdb1..c32a173a6672 100644 --- a/docs/source/en/model_doc/pegasus_x.md +++ b/docs/source/en/model_doc/pegasus_x.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview diff --git a/docs/source/en/model_doc/plbart.md b/docs/source/en/model_doc/plbart.md index bac567615d42..d57ee8ed99e8 100644 --- a/docs/source/en/model_doc/plbart.md +++ b/docs/source/en/model_doc/plbart.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview @@ -29,7 +31,7 @@ on Java, Python and English. According to the abstract *Code summarization and generation empower conversion between programming language (PL) and natural language (NL), -while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, +while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART, a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks. PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding. Experiments on code summarization in the English language, code generation, and code translation in seven programming languages @@ -50,7 +52,7 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used. However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this. -In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format +In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if it's passed with the `text_target` keyword argument. diff --git a/docs/source/en/model_doc/speech_to_text.md b/docs/source/en/model_doc/speech_to_text.md index bc65ea79655f..3e4680ca8dd2 100644 --- a/docs/source/en/model_doc/speech_to_text.md +++ b/docs/source/en/model_doc/speech_to_text.md @@ -19,6 +19,8 @@ rendered properly in your Markdown viewer.
PyTorch TensorFlow +FlashAttention +SDPA
## Overview diff --git a/docs/source/en/model_doc/time_series_transformer.md b/docs/source/en/model_doc/time_series_transformer.md index a91633b6b029..a06f2578d3d5 100644 --- a/docs/source/en/model_doc/time_series_transformer.md +++ b/docs/source/en/model_doc/time_series_transformer.md @@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
PyTorch +FlashAttention +SDPA
## Overview diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index ea7b00bca349..15c1476ced3e 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -1309,6 +1309,7 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def __init__( self, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 752da48cc35e..b4993fe4f725 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -573,6 +573,7 @@ class SEWPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 69caa26d3a4a..e2f2633c0942 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -393,6 +393,7 @@ class SEWPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + _supports_flex_attn = True def _init_weights(self, module): """Initialize the weights""" From 3e5da38678c557b1f35e80e052dbc08fa53724d8 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 11:10:52 +0200 Subject: [PATCH 21/68] disable sew flex attn as it's unclear attn mask for now --- src/transformers/models/sew/modeling_sew.py | 2 +- src/transformers/models/sew/modular_sew.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index b4993fe4f725..aff13a19f350 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -573,7 +573,7 @@ class SEWPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + _supports_flex_attn = False # needs a proper look into the mask creation def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index e2f2633c0942..97a8255b3d01 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -393,7 +393,7 @@ class SEWPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + _supports_flex_attn = False # needs a proper look into the mask creation def _init_weights(self, module): """Initialize the weights""" From 9a9b1404cbfb768d9097615424153a1b27bd520e Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 12:16:33 +0200 Subject: [PATCH 22/68] oops --- .../modeling_blenderbot_small.py | 88 ++++++++++++++++++- 1 file changed, 84 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 99b291e79c52..84f007417960 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -816,10 +816,6 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, @@ -979,6 +975,90 @@ def _update_causal_mask( return attention_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + @auto_docstring class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): From aecd5e2bdfb88c1d850ed48f24dc413980c1d47c Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 12:26:03 +0200 Subject: [PATCH 23/68] test fixes for enc-dec --- tests/test_modeling_common.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ff86e157fdb7..8dead13e820b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4267,7 +4267,16 @@ def test_flash_attn_2_from_config(self): dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) - _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + if config.is_encoder_decoder: + _ = fa2_model( + input_ids=dummy_input, + attention_mask=dummy_attention_mask, + decoder_input_ids=dummy_input.clone(), + decoder_attention_mask=dummy_attention_mask.clone(), + ) + else: + _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + with tempfile.TemporaryDirectory() as tmpdirname: fa2_model.save_pretrained(tmpdirname) model_from_pretrained = model_class.from_pretrained(tmpdirname) @@ -4502,7 +4511,14 @@ def test_flex_attention_with_grads(self): self.assertTrue(model.config._attn_implementation == "flex_attention") # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) - _ = model(inputs_dict["input_ids"].to(torch_device)) + dummy_input = inputs_dict["input_ids"].to(torch_device) + if config.is_encoder_decoder: + _ = model( + input_ids=dummy_input, + decoder_input_ids=dummy_input.clone(), + ) + else: + _ = model(input_ids=dummy_input) def test_generation_tester_mixin_inheritance(self): """ From 7bdb6920b5a9264345103af99171907a0567073b Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 13:14:03 +0200 Subject: [PATCH 24/68] torch fx fixes + try at flex attn --- src/transformers/models/bart/modeling_bart.py | 8 +++++++- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 8 +++++++- src/transformers/models/biogpt/modeling_biogpt.py | 8 +++++++- .../models/blenderbot/modeling_blenderbot.py | 8 +++++++- .../blenderbot_small/modeling_blenderbot_small.py | 10 +++++++++- .../models/data2vec/modeling_data2vec_audio.py | 8 +++++++- src/transformers/models/hubert/modeling_hubert.py | 8 +++++++- src/transformers/models/informer/modeling_informer.py | 8 +++++++- src/transformers/models/m2m_100/modeling_m2m_100.py | 8 +++++++- src/transformers/models/marian/modeling_marian.py | 8 +++++++- src/transformers/models/mbart/modeling_mbart.py | 8 +++++++- src/transformers/models/musicgen/modeling_musicgen.py | 8 +++++++- .../models/musicgen_melody/modeling_musicgen_melody.py | 8 +++++++- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 10 +++++++--- .../models/patchtsmixer/modeling_patchtsmixer.py | 8 +++++++- src/transformers/models/patchtst/modeling_patchtst.py | 8 +++++++- src/transformers/models/pegasus/modeling_pegasus.py | 8 +++++++- .../models/pegasus_x/modeling_pegasus_x.py | 8 +++++++- src/transformers/models/plbart/modeling_plbart.py | 8 +++++++- src/transformers/models/sew/modeling_sew.py | 8 +++++++- .../models/speech_to_text/modeling_speech_to_text.py | 8 +++++++- .../modeling_time_series_transformer.py | 8 +++++++- .../models/unispeech/modeling_unispeech.py | 8 +++++++- .../models/unispeech_sat/modeling_unispeech_sat.py | 8 +++++++- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 8 +++++++- tests/test_modeling_common.py | 2 +- 26 files changed, 178 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index cdfd286a07fd..14014b7b8683 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -187,8 +187,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index bc1ab64a6e5f..7b4bd6d879af 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1250,8 +1250,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 5ca8cebbe549..b62af5f5907f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -168,8 +168,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index e15e8f9c7f6b..a6b371ba686e 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -186,8 +186,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 84f007417960..51edac906587 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -170,8 +170,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) @@ -975,6 +981,7 @@ def _update_causal_mask( return attention_mask + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, None], @@ -1019,6 +1026,7 @@ def _update_causal_mask( return attention_mask + # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 9969ac180a10..d3f05c9bbb11 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -265,8 +265,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index c2540b3970e1..4928df5d0e15 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -322,8 +322,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 78f297564c2b..680ede0cf34a 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -346,8 +346,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index ebce0ca7d07f..e458ef4bbdce 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -259,8 +259,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a7c7f47f8afe..5e7fc03ede9c 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -184,8 +184,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 68e9eb5a2c01..c6ac9eea45b1 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -194,8 +194,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 0414093b3346..157bb6bbd652 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -233,8 +233,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 15c1476ced3e..f7f111b7f29a 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -249,8 +249,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 46e30d034234..b7b394e83982 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -554,10 +554,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = encoder_hidden_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = ( - (*encoder_hidden_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape - ) + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 1c270d088c0b..598a6458b41d 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -324,8 +324,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 7670055f46ae..6a7b1e178d9b 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -117,8 +117,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ef01f9e5e8d7..cf46f1ac74a6 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -185,8 +185,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index b13fa247ec89..2c5743d4834e 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -210,8 +210,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index eb7042db94d8..ada52b245dba 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -195,8 +195,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index aff13a19f350..96268da7559e 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -315,8 +315,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index dae7e3136f46..61a971401dda 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -256,8 +256,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index e3c432561861..f0f520164bc7 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -359,8 +359,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index ec2b2d73f7a1..e71d13edc537 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -361,8 +361,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 61a69dc88af8..72e92206527b 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -364,8 +364,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 54c060dc2b62..e1757cba4557 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -554,8 +554,14 @@ def forward( # determine input shapes bsz, tgt_len = hidden_states.shape[:-1] + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + + # certain models do not have a sequence per se + src_len = max(src_len, 1) + tgt_len = max(tgt_len, 1) + q_input_shape = (bsz, tgt_len, -1, self.head_dim) - kv_input_shape = (*key_value_states.shape[:-1], -1, self.head_dim) if is_cross_attention else q_input_shape + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 8dead13e820b..fd1f7381f47e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4511,7 +4511,7 @@ def test_flex_attention_with_grads(self): self.assertTrue(model.config._attn_implementation == "flex_attention") # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) - dummy_input = inputs_dict["input_ids"].to(torch_device) + dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) if config.is_encoder_decoder: _ = model( input_ids=dummy_input, From f8260e6748e36cf8e6ce633bcbe52d90ccce62da Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 13:29:06 +0200 Subject: [PATCH 25/68] skip on mbart --- tests/models/mbart/test_modeling_mbart.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index b80d11a8650f..579637aadd8c 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -748,3 +748,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot retain gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot retain gradients") + def test_flex_attention_with_grads(self): + return From 598a5669d21b22951798141d80548596b1c746a7 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 15:05:58 +0200 Subject: [PATCH 26/68] some more fixes --- tests/test_modeling_common.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index fd1f7381f47e..414ee14aad6b 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4507,18 +4507,12 @@ def test_flex_attention_with_grads(self): self.skipTest(reason="This model does not support flex attention") config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config._attn_implementation = "flex_attention" - model = model_class(config).to(device=torch_device, dtype=torch.float16) + model = model_class(config).to(device=torch_device) self.assertTrue(model.config._attn_implementation == "flex_attention") # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) - if config.is_encoder_decoder: - _ = model( - input_ids=dummy_input, - decoder_input_ids=dummy_input.clone(), - ) - else: - _ = model(input_ids=dummy_input) + _ = model(dummy_input) def test_generation_tester_mixin_inheritance(self): """ From 61b648f10b16b2b5d7fa6ee4bd95b3381ae2436b Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 15:35:28 +0200 Subject: [PATCH 27/68] musicgen skip / delete old attn class logic + sdpa compose compile skip --- .../models/musicgen/test_modeling_musicgen.py | 43 +++---------------- .../test_modeling_musicgen_melody.py | 43 +++---------------- 2 files changed, 14 insertions(+), 72 deletions(-) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index e80bb7948a14..2f2057493fcc 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -1062,30 +1062,7 @@ def test_flash_attn_2_inference_equivalence(self): @mark.flash_attn_test @slow def test_flash_attn_2_conversion(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - ).to(torch_device) - - for _, module in model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - return - - self.assertTrue(False, "FlashAttention2 modules not found in model") + self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.") @require_torch_sdpa @require_torch_gpu @@ -1260,18 +1237,6 @@ def test_sdpa_can_dispatch_composite_models(self): self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - def test_requires_grad_with_frozen_encoders(self): config = self.model_tester.get_config() for model_class in self.all_model_classes: @@ -1302,6 +1267,12 @@ def test_requires_grad_with_frozen_encoders(self): def test_generation_tester_mixin_inheritance(self): pass + @unittest.skip( + reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.") + ) + def test_sdpa_can_compile_dynamic(self): + pass + def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): """Produces a series of 'bip bip' sounds at a given frequency.""" diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index bf441bb19e6f..8d0e822625b7 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -1052,30 +1052,7 @@ def test_flash_attn_2_inference_equivalence(self): @mark.flash_attn_test @slow def test_flash_attn_2_conversion(self): - if not self.has_attentions: - self.skipTest(reason="Model architecture does not support attentions") - - config, _ = self.model_tester.prepare_config_and_inputs_for_common() - - for model_class in self.all_model_classes: - if not model_class._supports_flash_attn_2: - self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") - - model = model_class(config) - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_pretrained(tmpdirname) - model = model_class.from_pretrained( - tmpdirname, - torch_dtype=torch.float16, - attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None}, - ).to(torch_device) - - for _, module in model.named_modules(): - if "FlashAttention" in module.__class__.__name__: - return - - self.assertTrue(False, "FlashAttention2 modules not found in model") + self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.") @require_torch_sdpa @require_torch_gpu @@ -1250,18 +1227,6 @@ def test_sdpa_can_dispatch_composite_models(self): self.assertTrue(model_eager.decoder.config._attn_implementation == "eager") self.assertTrue(model_eager.config._attn_implementation == "eager") - for name, submodule in model_eager.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - raise ValueError("The eager model should not have SDPA attention layers") - - has_sdpa = False - for name, submodule in model_sdpa.named_modules(): - if "SdpaAttention" in submodule.__class__.__name__: - has_sdpa = True - break - if not has_sdpa and model_sdpa.config.model_type != "falcon": - raise ValueError("The SDPA model should have SDPA attention layers") - def test_requires_grad_with_frozen_encoders(self): config = self.model_tester.get_config() for model_class in self.all_model_classes: @@ -1292,6 +1257,12 @@ def test_requires_grad_with_frozen_encoders(self): def test_generation_tester_mixin_inheritance(self): pass + @unittest.skip( + reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.") + ) + def test_sdpa_can_compile_dynamic(self): + pass + # Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000): From 43169910a1a78fb66c70c2e1dc4a142bb823ccbb Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 15:57:24 +0200 Subject: [PATCH 28/68] disable flex attn for musicgen, not worth the effort --- src/transformers/models/musicgen/modeling_musicgen.py | 5 ++++- .../models/musicgen_melody/modeling_musicgen_melody.py | 6 ++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 157bb6bbd652..6a51e0327b84 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -450,7 +450,8 @@ class MusicgenPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # compilation errors occurr atm + _supports_flex_attn = False def _init_weights(self, module): std = self.config.initializer_factor @@ -1389,6 +1390,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True + # compilation errors occurr atm + _supports_flex_attn = False def __init__( self, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index f7f111b7f29a..8ed4dbcc2fac 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -425,7 +425,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel): _no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # compilation errors occurr atm + _supports_flex_attn = False def _init_weights(self, module): std = self.config.initializer_factor @@ -1315,7 +1316,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # compilation errors occurr atm + _supports_flex_attn = False def __init__( self, From 69371062e37d530fb43f7ffee1a7524913b09593 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 16:09:25 +0200 Subject: [PATCH 29/68] more fixes and style --- tests/models/m2m_100/test_modeling_m2m_100.py | 2 +- tests/models/musicgen/test_modeling_musicgen.py | 4 +--- .../models/musicgen_melody/test_modeling_musicgen_melody.py | 6 ++---- 3 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 1c33150adaf3..e85ef91a167f 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -439,7 +439,7 @@ def test_flash_attn_2_seq_to_seq_generation(self): Overwriting the common test as the test is flaky on tiny models """ model = M2M100ForConditionalGeneration.from_pretrained( - "facebook/m2m100_418M", attn_implementation="flash_attention_2" + "facebook/m2m100_418M", attn_implementation="flash_attention_2", torch_dtype=torch.float16 ).to(torch_device) tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en") diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 2f2057493fcc..5377f97b5ed2 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -1267,9 +1267,7 @@ def test_requires_grad_with_frozen_encoders(self): def test_generation_tester_mixin_inheritance(self): pass - @unittest.skip( - reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.") - ) + @unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.")) def test_sdpa_can_compile_dynamic(self): pass diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index 8d0e822625b7..5eaea27f5d7a 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -1052,7 +1052,7 @@ def test_flash_attn_2_inference_equivalence(self): @mark.flash_attn_test @slow def test_flash_attn_2_conversion(self): - self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.") + self.skipTest(reason="MusicgenMelody doesn't use the MusicgenMelodyFlashAttention2 class method.") @require_torch_sdpa @require_torch_gpu @@ -1257,9 +1257,7 @@ def test_requires_grad_with_frozen_encoders(self): def test_generation_tester_mixin_inheritance(self): pass - @unittest.skip( - reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.") - ) + @unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5.")) def test_sdpa_can_compile_dynamic(self): pass From 4f1234740c43a57631c67d36b28fbb08303622f7 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 16:29:05 +0200 Subject: [PATCH 30/68] flex attention test for dropout and encoder decoder that dont have main input names --- tests/test_modeling_common.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 414ee14aad6b..a637f25efa3c 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4510,9 +4510,17 @@ def test_flex_attention_with_grads(self): model = model_class(config).to(device=torch_device) self.assertTrue(model.config._attn_implementation == "flex_attention") + # Elaborate workaround for encoder-decoder models as some do not specify their main input + dummy_input = {"input_ids": inputs_dict[model_class.main_input_name].to(torch_device)} + if "decoder_input_ids" in inspect.signature(model.forward).parameters: + dummy_input["decoder_input_ids"] = dummy_input["input_ids"].clone() + + # Flex Attention can not use dropout + if hasattr(config, "attention_droput"): + config.attention_droput = 0 + # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) - dummy_input = inputs_dict[model_class.main_input_name].to(torch_device) - _ = model(dummy_input) + _ = model(**dummy_input) def test_generation_tester_mixin_inheritance(self): """ From 05e38b129221df2fce0693d16904b8ef5d92ba8c Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 16:48:15 +0200 Subject: [PATCH 31/68] informer fixes --- src/transformers/models/informer/modeling_informer.py | 6 ++++-- src/transformers/models/informer/modular_informer.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 680ede0cf34a..5ef1443bfb90 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -776,15 +776,17 @@ def __init__(self, config: InformerConfig): if config.attention_type == "prob": self.self_attn = InformerProbSparseAttention( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, sampling_factor=config.sampling_factor, + is_decoder=True, ) else: self.self_attn = InformerAttention( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, + is_decoder=True, config=config, ) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index d54b8b94c7b8..88ce69085d2c 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -376,15 +376,17 @@ def __init__(self, config: InformerConfig): if config.attention_type == "prob": self.self_attn = InformerProbSparseAttention( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, sampling_factor=config.sampling_factor, + is_decoder=True, ) else: self.self_attn = InformerAttention( embed_dim=self.embed_dim, - num_heads=config.encoder_attention_heads, + num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, + is_decoder=True, config=config, ) From 9a8d4e476a77451e5d29b67cd5150df9df5efac2 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 17:56:29 +0200 Subject: [PATCH 32/68] the weirdest thing I've encountered yet... --- .../models/informer/modeling_informer.py | 11 ++++++++++- .../modeling_time_series_transformer.py | 13 ++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 5ef1443bfb90..82c449b9d3a2 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1679,7 +1679,16 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - dec_input = transformer_inputs[:, self.config.context_length :, ...] + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] + decoder_outputs = self.decoder( inputs_embeds=dec_input, attention_mask=decoder_attention_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index f0f520164bc7..fca8c49d8e4d 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1452,7 +1452,18 @@ def forward( attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None, ) - dec_input = transformer_inputs[:, self.config.context_length :, ...] + # Avoid empty tensors and instead create a zeroes tensor which + # will be treated the same in torch, i.e. matmul with empty == all 0s + if self.config.context_length >= transformer_inputs.shape[1]: + bsz, _, dim = transformer_inputs.shape + dec_input = torch.zeros( + size=(bsz, 1, dim), + device=transformer_inputs.device, + dtype=transformer_inputs.dtype + ) + else: + dec_input = transformer_inputs[:, self.config.context_length :, ...] + decoder_outputs = self.decoder( inputs_embeds=dec_input, attention_mask=decoder_attention_mask, From 2055759f518ab9a8e73f7ede38e1ad9d8fcf3609 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 18:00:10 +0200 Subject: [PATCH 33/68] style --- .../modeling_time_series_transformer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index fca8c49d8e4d..c5d4cc159684 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -1457,9 +1457,7 @@ def forward( if self.config.context_length >= transformer_inputs.shape[1]: bsz, _, dim = transformer_inputs.shape dec_input = torch.zeros( - size=(bsz, 1, dim), - device=transformer_inputs.device, - dtype=transformer_inputs.dtype + size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype ) else: dec_input = transformer_inputs[:, self.config.context_length :, ...] From adc808d2c37112dfe8cfe1292a13ec7d5481a294 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 15 May 2025 18:10:29 +0200 Subject: [PATCH 34/68] remove empty tensor attempt, found core root in previous commits --- src/transformers/models/bart/modeling_bart.py | 4 ---- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 4 ---- src/transformers/models/biogpt/modeling_biogpt.py | 4 ---- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 ---- .../models/blenderbot_small/modeling_blenderbot_small.py | 4 ---- src/transformers/models/data2vec/modeling_data2vec_audio.py | 4 ---- src/transformers/models/hubert/modeling_hubert.py | 4 ---- src/transformers/models/informer/modeling_informer.py | 4 ---- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 ---- src/transformers/models/marian/modeling_marian.py | 4 ---- src/transformers/models/mbart/modeling_mbart.py | 4 ---- src/transformers/models/musicgen/modeling_musicgen.py | 4 ---- .../models/musicgen_melody/modeling_musicgen_melody.py | 4 ---- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 4 ---- src/transformers/models/patchtsmixer/modeling_patchtsmixer.py | 4 ---- src/transformers/models/patchtst/modeling_patchtst.py | 4 ---- src/transformers/models/pegasus/modeling_pegasus.py | 4 ---- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 ---- src/transformers/models/plbart/modeling_plbart.py | 4 ---- src/transformers/models/sew/modeling_sew.py | 4 ---- .../models/speech_to_text/modeling_speech_to_text.py | 4 ---- .../modeling_time_series_transformer.py | 4 ---- src/transformers/models/unispeech/modeling_unispeech.py | 4 ---- .../models/unispeech_sat/modeling_unispeech_sat.py | 4 ---- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 4 ---- 25 files changed, 100 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 14014b7b8683..a7bc4542c7ba 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -189,10 +189,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 7b4bd6d879af..9aac9fc320de 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1252,10 +1252,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index b62af5f5907f..446aeb46abd1 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -170,10 +170,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index a6b371ba686e..e71e47ce1af9 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -188,10 +188,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 51edac906587..61d2d948f599 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -172,10 +172,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index d3f05c9bbb11..676466ca972a 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -267,10 +267,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 4928df5d0e15..5c95c8ee9747 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -324,10 +324,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 82c449b9d3a2..33bbc8ba15f9 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -348,10 +348,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index e458ef4bbdce..3cfa33f1f1d0 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -261,10 +261,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 5e7fc03ede9c..7d4886ee3d0a 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -186,10 +186,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index c6ac9eea45b1..90018b43fe57 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -196,10 +196,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 6a51e0327b84..d834f5fc115b 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -235,10 +235,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 8ed4dbcc2fac..fdd8ea52336b 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -251,10 +251,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b7b394e83982..b22c2d138f83 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -556,10 +556,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = encoder_hidden_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 598a6458b41d..b3e07918865c 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -326,10 +326,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 6a7b1e178d9b..808ba3e95cf4 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -119,10 +119,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index cf46f1ac74a6..ec2805ec5441 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -187,10 +187,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 2c5743d4834e..1e3b5ebf3432 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -212,10 +212,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index ada52b245dba..6239f786acfb 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -197,10 +197,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 96268da7559e..761003d8bbef 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -317,10 +317,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 61a971401dda..4216fb417e98 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -258,10 +258,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index c5d4cc159684..a06e4e65e284 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -361,10 +361,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index e71d13edc537..33ad6bcd7c44 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -363,10 +363,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 72e92206527b..0794a7898578 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -366,10 +366,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index e1757cba4557..66340be5122e 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -556,10 +556,6 @@ def forward( bsz, tgt_len = hidden_states.shape[:-1] src_len = key_value_states.shape[1] if is_cross_attention else tgt_len - # certain models do not have a sequence per se - src_len = max(src_len, 1) - tgt_len = max(tgt_len, 1) - q_input_shape = (bsz, tgt_len, -1, self.head_dim) kv_input_shape = (bsz, src_len, -1, self.head_dim) From 3d23455a168317c492a461992982197b3a8a5728 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 10:11:05 +0200 Subject: [PATCH 35/68] disable time series due to tests being very text centric on inputs --- .../modeling_time_series_transformer.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index a06e4e65e284..4c773244bde6 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -639,9 +639,11 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "past_values" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True + # TODO: tests would need a rewrite to check for correct implementation + # Current tests always assume certain inputs to be passed + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std From 8f9de8681b21102243beab1fc442199ef22488b2 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 10:28:32 +0200 Subject: [PATCH 36/68] add speech to text to be ignoring the other attns, also due to tests --- .../models/speech_to_text/modeling_speech_to_text.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 4216fb417e98..60d7bf17617f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -534,9 +534,11 @@ class Speech2TextPreTrainedModel(PreTrainedModel): base_model_prefix = "model" main_input_name = "input_features" supports_gradient_checkpointing = True - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True + # TODO: tests would need a rewrite to check for correct implementation + # Current tests always assume certain inputs to be passed + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std From b94c96640b3719a0fbc8b286a0eb63d576dc4f8a Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 10:30:29 +0200 Subject: [PATCH 37/68] update docs --- docs/source/en/model_doc/speech_to_text.md | 2 -- docs/source/en/model_doc/time_series_transformer.md | 2 -- 2 files changed, 4 deletions(-) diff --git a/docs/source/en/model_doc/speech_to_text.md b/docs/source/en/model_doc/speech_to_text.md index 3e4680ca8dd2..bc65ea79655f 100644 --- a/docs/source/en/model_doc/speech_to_text.md +++ b/docs/source/en/model_doc/speech_to_text.md @@ -19,8 +19,6 @@ rendered properly in your Markdown viewer.
PyTorch TensorFlow -FlashAttention -SDPA
## Overview diff --git a/docs/source/en/model_doc/time_series_transformer.md b/docs/source/en/model_doc/time_series_transformer.md index a06f2578d3d5..a91633b6b029 100644 --- a/docs/source/en/model_doc/time_series_transformer.md +++ b/docs/source/en/model_doc/time_series_transformer.md @@ -18,8 +18,6 @@ rendered properly in your Markdown viewer.
PyTorch -FlashAttention -SDPA
## Overview From 6f813cd16abb23423bef9bf3d406c411fab7a235 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 12:08:35 +0200 Subject: [PATCH 38/68] remaining issues resolved ? --- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 9 ++++++--- .../models/patchtsmixer/modeling_patchtsmixer.py | 6 +----- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 3 ++- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b22c2d138f83..b03194519737 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -860,9 +860,12 @@ class NllbMoePreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"] - _supports_flash_attn_2 = True - _supports_sdpa = True - _supports_flex_attn = True + # TODO: If anyone is up to it to make sure tests pass etc + # Flash attention has problems due to not preparing masks the same way as eager/sdpa + # SDPA has more flaky logits which requires more time to look into tests + _supports_flash_attn_2 = False + _supports_sdpa = False + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index b3e07918865c..249bb71d3ab1 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -28,14 +28,10 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, logging from .configuration_patchtsmixer import PatchTSMixerConfig -if is_torch_flex_attn_available(): - pass - - logger = logging.get_logger(__name__) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 1e3b5ebf3432..b78589f57c73 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -762,7 +762,8 @@ class PegasusXPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"] _supports_flash_attn_2 = True - _supports_sdpa = True + # TODO: Flaky logits + _supports_sdpa = False _supports_flex_attn = True def _init_weights(self, module): From 3be2a9d43c00d024171afbd7b38507f385ad8655 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 12:19:18 +0200 Subject: [PATCH 39/68] update docs for current state --> nllb moe and pegasus x sdpa is questionable :D --- docs/source/en/model_doc/nllb-moe.md | 2 -- docs/source/en/model_doc/pegasus_x.md | 1 - 2 files changed, 3 deletions(-) diff --git a/docs/source/en/model_doc/nllb-moe.md b/docs/source/en/model_doc/nllb-moe.md index 3f0be7a7c96e..fc8c8c92115d 100644 --- a/docs/source/en/model_doc/nllb-moe.md +++ b/docs/source/en/model_doc/nllb-moe.md @@ -18,8 +18,6 @@ rendered properly in your Markdown viewer.
PyTorch -FlashAttention -SDPA
## Overview diff --git a/docs/source/en/model_doc/pegasus_x.md b/docs/source/en/model_doc/pegasus_x.md index c32a173a6672..97e50601b725 100644 --- a/docs/source/en/model_doc/pegasus_x.md +++ b/docs/source/en/model_doc/pegasus_x.md @@ -19,7 +19,6 @@ rendered properly in your Markdown viewer.
PyTorch FlashAttention -SDPA
## Overview From 6dbd77ade2f2ee93984e1ea371a7c57f9097bcb9 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 13:15:47 +0200 Subject: [PATCH 40/68] some models have not set the is_causal flag... --- src/transformers/integrations/sdpa_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 5c14df042e89..515eef5ae988 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -45,7 +45,7 @@ def sdpa_attention_forward( if is_causal is None: # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns - is_causal = query.shape[2] > 1 and attention_mask is None and module.is_causal + is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True) # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor. # We convert it to a bool for the SDPA kernel that only accepts bools. From dd3d3077754a765d275e0be0506b16d52d87bcda Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 15:23:22 +0200 Subject: [PATCH 41/68] change dtype in softmax tol old behaviour + some modular fixes --- src/transformers/models/bart/modeling_bart.py | 2 +- .../modeling_bigbird_pegasus.py | 2 +- .../models/biogpt/modeling_biogpt.py | 2 +- .../models/blenderbot/modeling_blenderbot.py | 2 +- .../modeling_blenderbot_small.py | 2 +- .../data2vec/modeling_data2vec_audio.py | 2 +- .../models/hubert/modeling_hubert.py | 2 +- .../models/informer/modeling_informer.py | 2 +- .../models/m2m_100/modeling_m2m_100.py | 2 +- .../models/marian/modeling_marian.py | 2 +- .../models/mbart/modeling_mbart.py | 2 +- .../models/musicgen/modeling_musicgen.py | 2 +- .../modeling_musicgen_melody.py | 2 +- .../models/nllb_moe/modeling_nllb_moe.py | 2 +- .../patchtsmixer/modeling_patchtsmixer.py | 2 +- .../models/patchtst/modeling_patchtst.py | 2 +- .../models/pegasus/modeling_pegasus.py | 2 +- .../models/pegasus_x/modeling_pegasus_x.py | 2 +- .../models/plbart/modeling_plbart.py | 10 +- .../models/plbart/modular_plbart.py | 145 +----------------- src/transformers/models/sew/modeling_sew.py | 3 +- src/transformers/models/sew/modular_sew.py | 122 +-------------- .../speech_to_text/modeling_speech_to_text.py | 2 +- .../modeling_time_series_transformer.py | 2 +- .../models/unispeech/modeling_unispeech.py | 2 +- .../unispeech_sat/modeling_unispeech_sat.py | 2 +- .../models/wav2vec2/modeling_wav2vec2.py | 2 +- 27 files changed, 35 insertions(+), 291 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a7bc4542c7ba..d99a9610c493 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -121,7 +121,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 9aac9fc320de..4e5e7e089e4d 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1183,7 +1183,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 446aeb46abd1..0873745a632f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -102,7 +102,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index e71e47ce1af9..7c293ab1b290 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -119,7 +119,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 61d2d948f599..345f2b2cfe1f 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -103,7 +103,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 676466ca972a..462fdf56023d 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -199,7 +199,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 5c95c8ee9747..6048be6b59d4 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -256,7 +256,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 33bbc8ba15f9..44c9a758f120 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -280,7 +280,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 3cfa33f1f1d0..314fe51f9598 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -192,7 +192,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 7d4886ee3d0a..0ddafe75287c 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -117,7 +117,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 90018b43fe57..3cc766757f51 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -127,7 +127,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index d834f5fc115b..acb03d9d3a10 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -166,7 +166,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index fdd8ea52336b..3eafc8017b76 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -182,7 +182,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index b03194519737..452c89ff969d 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -487,7 +487,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 249bb71d3ab1..ee243a6c79a7 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -253,7 +253,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 808ba3e95cf4..807c56e0833d 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -50,7 +50,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ec2805ec5441..f755a2b9d15f 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -118,7 +118,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index b78589f57c73..8a16e232981e 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -143,7 +143,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 6239f786acfb..96f90eeab3ce 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -129,7 +129,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) @@ -1022,10 +1022,9 @@ def _update_cross_attn_mask( return encoder_attention_mask -# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): """ - Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not + Shift input ids one token to the right, and wrap the last non pad token (the token) Note that PLBart does not have a single `decoder_start_token_id` in contrast to other Bart-like models. """ prev_output_tokens = input_ids.clone() @@ -1395,8 +1394,8 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @auto_docstring( custom_intro=""" - PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code - classification. + PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. + for GLUE tasks. """ ) class PLBartForSequenceClassification(PLBartPreTrainedModel): @@ -1416,7 +1415,6 @@ def __init__(self, config: PLBartConfig, **kwargs): self.post_init() @auto_docstring - # Ignore copy def forward( self, input_ids: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 27152541d000..0d0b9fd644da 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -20,14 +20,13 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from torch.nn import CrossEntropyLoss from ...generation import GenerationMixin from ...modeling_outputs import ( BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput, - Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring @@ -38,30 +37,11 @@ BartForCausalLM, BartScaledWordEmbedding, ) +from ..bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForSequenceClassification +from ..mbart.modeling_mbart import shift_tokens_right from .configuration_plbart import PLBartConfig -# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right -def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): - """ - Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not - have a single `decoder_start_token_id` in contrast to other Bart-like models. - """ - prev_output_tokens = input_ids.clone() - - if pad_token_id is None: - raise ValueError("self.model.config.pad_token_id has to be defined.") - # replace possible -100 values in labels by `pad_token_id` - prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id) - - index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1) - decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze() - prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone() - prev_output_tokens[:, 0] = decoder_start_tokens - - return prev_output_tokens - - class PLBartScaledWordEmbedding(BartScaledWordEmbedding): pass @@ -426,48 +406,8 @@ class PLBartClassificationHead(BartClassificationHead): pass -@auto_docstring( - custom_intro=""" - PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code - classification. - """ -) -class PLBartForSequenceClassification(PLBartPreTrainedModel): - _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"] - - def __init__(self, config: PLBartConfig, **kwargs): - super().__init__(config, **kwargs) - self.model = PLBartModel(config) - self.classification_head = PLBartClassificationHead( - config.d_model, - config.d_model, - config.num_labels, - config.classifier_dropout, - ) - - # Initialize weights and apply final processing - self.post_init() - - @auto_docstring - # Ignore copy - def forward( - self, - input_ids: Optional[torch.LongTensor] = None, - attention_mask: Optional[torch.Tensor] = None, - decoder_input_ids: Optional[torch.LongTensor] = None, - decoder_attention_mask: Optional[torch.LongTensor] = None, - head_mask: Optional[torch.Tensor] = None, - decoder_head_mask: Optional[torch.Tensor] = None, - cross_attn_head_mask: Optional[torch.Tensor] = None, - encoder_outputs: Optional[List[torch.FloatTensor]] = None, - inputs_embeds: Optional[torch.FloatTensor] = None, - decoder_inputs_embeds: Optional[torch.FloatTensor] = None, - labels: Optional[torch.LongTensor] = None, - use_cache: Optional[bool] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: +class PLBartForSequenceClassification(BigBirdPegasusForSequenceClassification): + def forward(**super_kwargs): r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): Indices of decoder input sequence tokens in the vocabulary. @@ -500,80 +440,7 @@ def forward( Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy). """ - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if labels is not None: - use_cache = False - - if input_ids is None and inputs_embeds is not None: - raise NotImplementedError( - f"Passing input embeddings is currently not supported for {self.__class__.__name__}" - ) - - outputs = self.model( - input_ids, - attention_mask=attention_mask, - decoder_input_ids=decoder_input_ids, - decoder_attention_mask=decoder_attention_mask, - head_mask=head_mask, - decoder_head_mask=decoder_head_mask, - cross_attn_head_mask=cross_attn_head_mask, - encoder_outputs=encoder_outputs, - inputs_embeds=inputs_embeds, - decoder_inputs_embeds=decoder_inputs_embeds, - use_cache=use_cache, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - hidden_states = outputs[0] # last hidden state - - eos_mask = input_ids.eq(self.config.eos_token_id).to(hidden_states.device) - - if len(torch.unique_consecutive(eos_mask.sum(1))) > 1: - raise ValueError("All examples must have the same number of tokens.") - sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[ - :, -1, : - ] - logits = self.classification_head(sentence_representation) - - loss = None - if labels is not None: - labels = labels.to(logits.device) - if self.config.problem_type is None: - if self.config.num_labels == 1: - self.config.problem_type = "regression" - elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): - self.config.problem_type = "single_label_classification" - else: - self.config.problem_type = "multi_label_classification" - - if self.config.problem_type == "regression": - loss_fct = MSELoss() - if self.config.num_labels == 1: - loss = loss_fct(logits.squeeze(), labels.squeeze()) - else: - loss = loss_fct(logits, labels) - elif self.config.problem_type == "single_label_classification": - loss_fct = CrossEntropyLoss() - loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1)) - elif self.config.problem_type == "multi_label_classification": - loss_fct = BCEWithLogitsLoss() - loss = loss_fct(logits, labels) - if not return_dict: - output = (logits,) + outputs[1:] - return ((loss,) + output) if loss is not None else output - - return Seq2SeqSequenceClassifierOutput( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - decoder_hidden_states=outputs.decoder_hidden_states, - decoder_attentions=outputs.decoder_attentions, - cross_attentions=outputs.cross_attentions, - encoder_last_hidden_state=outputs.encoder_last_hidden_state, - encoder_hidden_states=outputs.encoder_hidden_states, - encoder_attentions=outputs.encoder_attentions, - ) + super().forward(**super_kwargs) class PLBartForCausalLM(BartForCausalLM): diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 761003d8bbef..999c9b23c65b 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -249,7 +249,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) @@ -637,7 +637,6 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti return attention_mask -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 97a8255b3d01..1f8b236e8760 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -18,7 +18,6 @@ import warnings from typing import Optional, Tuple, Union -import numpy as np import torch import torch.utils.checkpoint from torch import nn @@ -40,6 +39,7 @@ Wav2Vec2LayerNormConvLayer, Wav2Vec2NoLayerNormConvLayer, Wav2Vec2SamePadLayer, + _compute_mask_indices, ) from .configuration_sew import SEWConfig @@ -47,126 +47,6 @@ _HIDDEN_STATES_START_POSITION = 1 -# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices -def _compute_mask_indices( - shape: Tuple[int, int], - mask_prob: float, - mask_length: int, - attention_mask: Optional[torch.LongTensor] = None, - min_masks: int = 0, -) -> np.ndarray: - """ - Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for - ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on - CPU as part of the preprocessing during training. - - Args: - shape: The shape for which to compute masks. This should be of a tuple of size 2 where - the first element is the batch size and the second element is the length of the axis to span. - mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of - independently generated mask spans of length `mask_length` is computed by - `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the - actual percentage will be smaller. - mask_length: size of the mask - min_masks: minimum number of masked spans - attention_mask: A (right-padded) attention mask which independently shortens the feature axis of - each batch dimension. - """ - batch_size, sequence_length = shape - - if mask_length < 1: - raise ValueError("`mask_length` has to be bigger than 0.") - - if mask_length > sequence_length: - raise ValueError( - f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}" - f" and `sequence_length`: {sequence_length}`" - ) - - # epsilon is used for probabilistic rounding - epsilon = np.random.rand(1).item() - - def compute_num_masked_span(input_length): - """Given input length, compute how many spans should be masked""" - num_masked_span = int(mask_prob * input_length / mask_length + epsilon) - num_masked_span = max(num_masked_span, min_masks) - - # make sure num masked span <= sequence_length - if num_masked_span * mask_length > sequence_length: - num_masked_span = sequence_length // mask_length - - # make sure num_masked span is also <= input_length - (mask_length - 1) - if input_length - (mask_length - 1) < num_masked_span: - num_masked_span = max(input_length - (mask_length - 1), 0) - - return num_masked_span - - # compute number of masked spans in batch - input_lengths = ( - attention_mask.detach().sum(-1).tolist() - if attention_mask is not None - else [sequence_length for _ in range(batch_size)] - ) - - # SpecAugment mask to fill - spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool) - spec_aug_mask_idxs = [] - - max_num_masked_span = compute_num_masked_span(sequence_length) - - if max_num_masked_span == 0: - return spec_aug_mask - - for input_length in input_lengths: - # compute num of masked spans for this input - num_masked_span = compute_num_masked_span(input_length) - - # get random indices to mask - spec_aug_mask_idx = np.random.choice( - np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False - ) - - # pick first sampled index that will serve as a dummy index to pad vector - # to ensure same dimension for all batches due to probabilistic rounding - # Picking first sample just pads those vectors twice. - if len(spec_aug_mask_idx) == 0: - # this case can only happen if `input_length` is strictly smaller then - # `sequence_length` in which case the last token has to be a padding - # token which we can use as a dummy mask id - dummy_mask_idx = sequence_length - 1 - else: - dummy_mask_idx = spec_aug_mask_idx[0] - - spec_aug_mask_idx = np.concatenate( - [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx] - ) - spec_aug_mask_idxs.append(spec_aug_mask_idx) - - spec_aug_mask_idxs = np.array(spec_aug_mask_idxs) - - # expand masked indices to masked spans - spec_aug_mask_idxs = np.broadcast_to( - spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length) - ) - spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length) - - # add offset to the starting indexes so that indexes now create a span - offsets = np.arange(mask_length)[None, None, :] - offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape( - batch_size, max_num_masked_span * mask_length - ) - spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - - # ensure that we cannot have indices larger than sequence_length - if spec_aug_mask_idxs.max() > sequence_length - 1: - spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1 - - # scatter indices to mask - np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1) - - return spec_aug_mask - - class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer): pass diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 60d7bf17617f..5a90308198ef 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -189,7 +189,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 4c773244bde6..cc3e99c72dd0 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -292,7 +292,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 33ad6bcd7c44..b49b5ffbec49 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -295,7 +295,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 0794a7898578..f41b1bd9c629 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -298,7 +298,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 66340be5122e..efc0bcd4371a 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -487,7 +487,7 @@ def eager_attn_forward( if attention_mask is not None: attn_weights = attn_weights + attention_mask - attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = nn.functional.softmax(attn_weights, dim=-1) if layer_head_mask is not None: attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) From d77ea86f57706144c32b7e189aec9285af4f70e0 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 16:13:58 +0200 Subject: [PATCH 42/68] I hate it but it is what it is --- tests/test_modeling_common.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index a637f25efa3c..c306399b3f33 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4511,9 +4511,12 @@ def test_flex_attention_with_grads(self): self.assertTrue(model.config._attn_implementation == "flex_attention") # Elaborate workaround for encoder-decoder models as some do not specify their main input - dummy_input = {"input_ids": inputs_dict[model_class.main_input_name].to(torch_device)} - if "decoder_input_ids" in inspect.signature(model.forward).parameters: - dummy_input["decoder_input_ids"] = dummy_input["input_ids"].clone() + if "input_ids" in inspect.signature(model.forward).parameters: + dummy_input = {"input_ids": inputs_dict[model_class.main_input_name].to(torch_device)} + if "decoder_input_ids" in inspect.signature(model.forward).parameters: + dummy_input["decoder_input_ids"] = dummy_input["input_ids"].clone() + else: + dummy_input = {model_class.main_input_name: inputs_dict[model_class.main_input_name].to(torch_device)} # Flex Attention can not use dropout if hasattr(config, "attention_droput"): From 71f7f1b205b3637900eedf4c0e25fd98d544e9dc Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 16:46:23 +0200 Subject: [PATCH 43/68] fixes from main for bart --- tests/models/bart/test_modeling_bart.py | 33 ++++++------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 6399cdf8725d..34fbcd80f9bf 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -58,28 +58,16 @@ def prepare_bart_inputs_dict( decoder_input_ids=None, attention_mask=None, decoder_attention_mask=None, - head_mask=None, - decoder_head_mask=None, - cross_attn_head_mask=None, ): if attention_mask is None: attention_mask = input_ids.ne(config.pad_token_id) if decoder_attention_mask is None: decoder_attention_mask = decoder_input_ids.ne(config.pad_token_id) - if head_mask is None: - head_mask = torch.ones(config.encoder_layers, config.encoder_attention_heads, device=torch_device) - if decoder_head_mask is None: - decoder_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) - if cross_attn_head_mask is None: - cross_attn_head_mask = torch.ones(config.decoder_layers, config.decoder_attention_heads, device=torch_device) return { "input_ids": input_ids, "decoder_input_ids": decoder_input_ids, "attention_mask": attention_mask, "decoder_attention_mask": attention_mask, - "head_mask": head_mask, - "decoder_head_mask": decoder_head_mask, - "cross_attn_head_mask": cross_attn_head_mask, } @@ -99,7 +87,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=50, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -167,10 +155,9 @@ def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict): model = BartModel(config=config).get_decoder().to(torch_device).eval() input_ids = inputs_dict["input_ids"] attention_mask = inputs_dict["attention_mask"] - head_mask = inputs_dict["head_mask"] # first forward pass - outputs = model(input_ids, attention_mask=attention_mask, head_mask=head_mask, use_cache=True) + outputs = model(input_ids, attention_mask=attention_mask, use_cache=True) output, past_key_values = outputs.to_tuple() @@ -1177,8 +1164,7 @@ def test_cnn_summarization_same_as_fairseq(self): [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], max_length=1024, padding="max_length", - truncation_strategy="only_first", - truncation=True, + truncation="only_first", return_tensors="pt", ) @@ -1314,7 +1300,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -1378,6 +1364,7 @@ def prepare_config_and_inputs(self): decoder_start_token_id=self.decoder_start_token_id, max_position_embeddings=self.max_position_embeddings, is_encoder_decoder=self.is_encoder_decoder, + forced_eos_token_id=None, ) return ( @@ -1478,9 +1465,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() @@ -1534,7 +1521,3 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return - - @unittest.skip(reason="Decoder cannot keep gradients") - def test_flex_attention_with_grads(): - return From 8a43566d31f0329af188994a8aba57edaa0ddea8 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 16:47:38 +0200 Subject: [PATCH 44/68] forgot this one --- tests/models/bart/test_modeling_bart.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 34fbcd80f9bf..ded8d5f0a8e8 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -1521,3 +1521,7 @@ def test_decoder_model_attn_mask_past(self): @unittest.skip(reason="Decoder cannot keep gradients") def test_retain_grad_hidden_states_attentions(self): return + + @unittest.skip(reason="Decoder cannot keep gradients") + def test_flex_attention_with_grads(): + return From 270c42ab2891e530590f7ff158a9ec5ec1b915c6 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 17:36:42 +0200 Subject: [PATCH 45/68] some model fixes --- tests/models/biogpt/test_modeling_biogpt.py | 18 +++++++----------- .../models/informer/test_modeling_informer.py | 3 ++- tests/models/m2m_100/test_modeling_m2m_100.py | 4 ++-- .../patchtsmixer/test_modeling_patchtsmixer.py | 2 +- .../pegasus_x/test_modeling_pegasus_x.py | 7 +++---- 5 files changed, 15 insertions(+), 19 deletions(-) diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index ec6c64b55685..232f7176b23c 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -135,9 +135,7 @@ def create_and_check_model( result = model(input_ids) self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size)) - def create_and_check_biogpt_model_attention_mask_past( - self, config, input_ids, input_mask, head_mask, token_type_ids, *args - ): + def create_and_check_biogpt_model_attention_mask_past(self, config, input_ids, input_mask, token_type_ids, *args): model = BioGptModel(config=config) model.to(torch_device) model.eval() @@ -177,9 +175,7 @@ def create_and_check_biogpt_model_attention_mask_past( # test that outputs are equal for slice self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) - def create_and_check_biogpt_model_past_large_inputs( - self, config, input_ids, input_mask, head_mask, token_type_ids, *args - ): + def create_and_check_biogpt_model_past_large_inputs(self, config, input_ids, input_mask, token_type_ids, *args): model = BioGptModel(config=config).to(torch_device).eval() attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device) @@ -213,7 +209,7 @@ def create_and_check_biogpt_model_past_large_inputs( self.parent.assertTrue(torch.allclose(output_from_past_slice, output_from_no_past_slice, atol=1e-3)) def create_and_check_forward_and_backwards( - self, config, input_ids, input_mask, head_mask, token_type_ids, *args, gradient_checkpointing=False + self, config, input_ids, input_mask, token_type_ids, *args, gradient_checkpointing=False ): model = BioGptForCausalLM(config) model.to(torch_device) @@ -233,9 +229,7 @@ def create_and_check_biogpt_weight_initialization(self, config, *args): self.parent.assertLessEqual(abs(torch.std(model.state_dict()[key]) - model_std), 0.001) self.parent.assertLessEqual(abs(torch.mean(model.state_dict()[key]) - 0.0), 0.01) - def create_and_check_biogpt_for_token_classification( - self, config, input_ids, input_mask, head_mask, token_type_ids, *args - ): + def create_and_check_biogpt_for_token_classification(self, config, input_ids, input_mask, token_type_ids, *args): config.num_labels = self.num_labels model = BioGptForTokenClassification(config) model.to(torch_device) @@ -326,6 +320,7 @@ def test_batch_generation(self): # Define PAD Token = EOS Token = 50256 tokenizer.pad_token = tokenizer.eos_token model.config.pad_token_id = model.config.eos_token_id + model.generation_config.pad_token_id = model.generation_config.eos_token_id # use different length sentences to test batching sentences = [ @@ -339,10 +334,11 @@ def test_batch_generation(self): outputs = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), + max_new_tokens=10, ) inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) + output_non_padded = model.generate(input_ids=inputs_non_padded, max_new_tokens=10) num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item() inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index cb197631975e..92fddb499891 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -170,8 +170,9 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict): embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model - ).to(torch_device) + ) embed_positions._init_weight() + embed_positions = embed_positions.to(torch_device) self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight)) diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index e85ef91a167f..69458160cd3b 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -370,7 +370,7 @@ def test_inference_no_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device + [[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]], device=torch_device ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -387,7 +387,7 @@ def test_inference_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device + [[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]], device=torch_device ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) diff --git a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py index ad00eab111f6..9b48a2e30740 100644 --- a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py +++ b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py @@ -480,7 +480,7 @@ def test_pretrain_head(self): ) self.assertEqual(output.shape, expected_shape) - expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device) # fmt: skip + expected_slice = torch.tensor([[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],],device=torch_device) # fmt: skip torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) def test_forecasting_head(self): diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index c1862d9301ee..e11017f6f19c 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -590,7 +590,7 @@ def test_inference_no_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]], device=torch_device + [[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]]], device=torch_device ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -608,7 +608,7 @@ def test_inference_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]], device=torch_device + [[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]], device=torch_device ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -635,8 +635,7 @@ def test_seq_to_seq_generation(self): batch_input, max_length=512, padding="max_length", - truncation_strategy="only_first", - truncation=True, + truncation="only_first", return_tensors="pt", ) From cc6cae0da714b0cac7e756bc6aa1febaa70c0eed Mon Sep 17 00:00:00 2001 From: Vasqu Date: Fri, 16 May 2025 17:37:57 +0200 Subject: [PATCH 46/68] style --- tests/models/m2m_100/test_modeling_m2m_100.py | 6 ++++-- tests/models/pegasus_x/test_modeling_pegasus_x.py | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 69458160cd3b..642a93dad1c5 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -370,7 +370,8 @@ def test_inference_no_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]], device=torch_device + [[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]], + device=torch_device, ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) @@ -387,7 +388,8 @@ def test_inference_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]], device=torch_device + [[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]], + device=torch_device, ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index e11017f6f19c..5a63cc56181f 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -608,7 +608,8 @@ def test_inference_head(self): self.assertEqual(output.shape, expected_shape) # change to expected output here expected_slice = torch.tensor( - [[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]], device=torch_device + [[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]], + device=torch_device, ) torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE) From 6369055e2911da9acf9e89308c4c6d88bd161b47 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 19 May 2025 16:10:39 +0200 Subject: [PATCH 47/68] current status --- .../generation/candidate_generator.py | 8 +- .../models/autoformer/modeling_autoformer.py | 2 +- .../models/big_bird/modeling_big_bird.py | 1 - .../modeling_bigbird_pegasus.py | 420 +++++++++++--- .../models/biogpt/modular_biogpt.py | 350 ++++++++---- .../models/blenderbot/modeling_blenderbot.py | 98 ++-- .../models/marian/modeling_marian.py | 519 +++++++++++------ .../models/mbart/modeling_mbart.py | 56 +- .../models/nllb_moe/modeling_nllb_moe.py | 2 +- .../patchtsmixer/modeling_patchtsmixer.py | 2 +- .../models/patchtst/modeling_patchtst.py | 2 +- .../models/plbart/modeling_plbart.py | 538 +++++++++++------- .../models/plbart/modular_plbart.py | 223 +++++++- src/transformers/models/sew/modular_sew.py | 4 +- .../speech_to_text/modeling_speech_to_text.py | 6 +- tests/generation/test_utils.py | 1 + .../test_modeling_bigbird_pegasus.py | 4 +- tests/models/marian/test_modeling_marian.py | 10 +- tests/models/mbart/test_modeling_mbart.py | 8 +- tests/models/plbart/test_modeling_plbart.py | 8 +- tests/test_modeling_common.py | 11 +- 21 files changed, 1588 insertions(+), 685 deletions(-) diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 3425a0234b42..bb9222030553 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -28,7 +28,7 @@ if is_sklearn_available(): from sklearn.metrics import roc_curve -from ..cache_utils import DynamicCache +from ..cache_utils import Cache from ..pytorch_utils import isin_mps_friendly from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor @@ -1183,7 +1183,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, def _crop_past_key_values(model, past_key_values, max_length): """Crops the past key values up to a certain maximum length.""" new_past = [] - if model.config.is_encoder_decoder: + if isinstance(past_key_values, Cache): + past_key_values.crop(max_length) + elif model.config.is_encoder_decoder: for idx in range(len(past_key_values)): new_past.append( ( @@ -1204,8 +1206,6 @@ def _crop_past_key_values(model, past_key_values, max_length): else: for idx in range(len(past_key_values)): past_key_values[idx] = past_key_values[idx][:, :, :max_length, :] - elif isinstance(past_key_values, DynamicCache): - past_key_values.crop(max_length) elif past_key_values is not None: for idx in range(len(past_key_values)): if past_key_values[idx] != ([], []): diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index c9aae04f4099..ed6cf2d5473b 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1048,7 +1048,7 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 5bae8894716e..2106c07e7dfd 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1340,7 +1340,6 @@ def set_attention_type(self, value: str): attn_weights.value = self.self.value attn_weights.key = self.self.key self.self = attn_weights - self.attention_type = value if not self.training: self.self.eval() diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 4e5e7e089e4d..f21506121a46 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -24,8 +24,13 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( BaseModelOutput, @@ -38,10 +43,14 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_bigbird_pegasus import BigBirdPegasusConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _EXPECTED_OUTPUT_SHAPE = [1, 7, 1024] @@ -71,13 +80,15 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__(num_embeddings, embedding_dim) - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor = None): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus @@ -1116,7 +1127,6 @@ def set_attention_type(self, value: str): if value == self.attention_type: return - self.attention_type = value if value == "original_full": # copy all weights to new full attention class attn_weights = BigBirdPegasusSelfAttention(self.config) @@ -1138,7 +1148,6 @@ def forward( hidden_states, attention_mask=None, head_mask=None, - past_key_value=None, output_attentions=False, band_mask=None, from_mask=None, @@ -1149,12 +1158,11 @@ def forward( # Expand dims to enable multiplication in the self-attention module head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None - if self.config.attention_type == "original_full": + if self.attention_type == "original_full": self_outputs = self.self( hidden_states, attention_mask, head_mask, - past_key_value=past_key_value, output_attentions=output_attentions, ) else: @@ -1208,6 +1216,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BigBirdPegasusConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -1224,6 +1233,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -1234,10 +1250,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -1258,42 +1275,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -1420,7 +1432,7 @@ def set_attention_type(self, value: str): class BigBirdPegasusDecoderLayer(nn.Module): - def __init__(self, config: BigBirdPegasusConfig): + def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model self.self_attn = BigBirdPegasusDecoderAttention( @@ -1430,6 +1442,7 @@ def __init__(self, config: BigBirdPegasusConfig): is_decoder=True, bias=config.use_bias, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1443,6 +1456,7 @@ def __init__(self, config: BigBirdPegasusConfig): is_decoder=True, bias=config.use_bias, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -1461,6 +1475,7 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.LongTensor] = None, ) -> torch.Tensor: """ Args: @@ -1484,42 +1499,35 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -1535,7 +1543,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1573,6 +1581,8 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_param_buffer_assignment = False + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -1584,6 +1594,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -1595,6 +1608,185 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + + class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): """ @@ -1924,7 +2116,7 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1951,6 +2143,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): r""" Args: @@ -2016,6 +2209,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2024,8 +2220,15 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() @@ -2035,43 +2238,69 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + attention_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + _unsupported_features, + self.config.attention_dropout, + ) + encoder_attention_mask = self._update_cross_attn_mask( + encoder_hidden_states, + encoder_attention_mask, + input_shape, + inputs_embeds, + _unsupported_features, + self.config.attention_dropout, + ) + # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) positions = positions.to(inputs_embeds.device) hidden_states = inputs_embeds + positions hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -2090,8 +2319,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -2104,6 +2331,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -2115,14 +2343,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -2137,6 +2366,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -2208,6 +2440,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2279,6 +2512,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -2369,6 +2603,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2442,6 +2677,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -2524,6 +2760,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2569,6 +2806,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] # last hidden state @@ -2656,6 +2894,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2693,6 +2932,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -2804,6 +3044,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -2852,6 +3093,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 351b072fc68e..8e1c14703fff 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -24,10 +24,10 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, + AttentionMaskConverter, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -42,6 +42,7 @@ LossKwargs, auto_docstring, is_torch_flex_attn_available, + is_torchdynamo_compiling, logger, ) from ..bart.modeling_bart import ( @@ -49,35 +50,23 @@ BartDecoderLayer, BartScaledWordEmbedding, ) +from ..opt.modeling_opt import OPTLearnedPositionalEmbedding from .configuration_biogpt import BioGptConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask -class BioGptLearnedPositionalEmbedding(nn.Embedding): - """ - This module learns positional embeddings up to a fixed maximum size. - """ - - def __init__(self, num_embeddings: int, embedding_dim: int): - # BioGpt is set up so that if padding_idx is specified then offset the embedding ids by 2 - # and adjust num_embeddings appropriately. Other models don't have this hack - self.offset = 2 - super().__init__(num_embeddings + self.offset, embedding_dim) - - def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): +class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding): + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.LongTensor] = None, + ): """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = attention_mask.long() - - # create positions depending on attention_mask - positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 - - # cut positions if `past_key_values_length` is > 0 - positions = positions[:, past_key_values_length:] - - return super().forward(positions + self.offset) + super().forward(attention_mask, past_key_values_length, position_ids) class BioGptScaledWordEmbedding(BartScaledWordEmbedding): @@ -89,7 +78,7 @@ class BioGptAttention(BartAttention): class BioGptDecoderLayer(BartDecoderLayer): - def __init__(self, config: BioGptConfig): + def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): super().__init__(config) self.embed_dim = config.hidden_size @@ -100,6 +89,7 @@ def __init__(self, config: BioGptConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.hidden_dropout_prob self.activation_fn = ACT2FN[config.hidden_act] @@ -115,9 +105,11 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + position_ids: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ @@ -134,21 +126,23 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + position_ids=position_ids, + cache_position=cache_position, **flash_attn_kwargs, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -170,7 +164,7 @@ def forward( outputs += (self_attn_weights,) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -182,7 +176,10 @@ class BioGptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # compiling issues + _supports_flex_attn = False + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): """Initialize the weights""" @@ -200,6 +197,142 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @auto_docstring class BioGptModel(BioGptPreTrainedModel): @@ -217,7 +350,7 @@ def __init__(self, config: BioGptConfig): ) self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) - self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.layer_norm = nn.LayerNorm(self.embed_dim) self.gradient_checkpointing = False @@ -239,9 +372,11 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **flash_attn_kwargs: Unpack[FlashAttentionKwargs], ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -252,49 +387,75 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input = input_ids - input_shape = input.size() + input_shape = input.shape + input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] input = inputs_embeds[:, :, -1] else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) - if attention_mask is None: - attention_mask = torch.ones( - (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length), - dtype=torch.bool, - device=inputs_embeds.device, + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize past_key_values + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - elif attention_mask.shape[1] != past_key_values_length + input_shape[1]: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)" + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # embed positions - positions = self.embed_positions(attention_mask, past_key_values_length) + if attention_mask is None: + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) - _unsupported_features = output_attentions is True or head_mask is not None - attention_mask = self._update_causal_mask( + causal_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, - _unsupported_features, + cache_position, + self_attn_cache, + output_attentions, ) + # embed positions + if position_ids is None: + # position_ids = cache_position.unsqueeze(0) + position_ids = torch.cumsum(attention_mask, dim=1) + position_ids = (position_ids * attention_mask - 1).long() + # cut positions if `past_seen_tokens` is > 0 + position_ids = position_ids[:, past_key_values_length:] + + positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) hidden_states = inputs_embeds + positions - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) if self.gradient_checkpointing and self.training: @@ -318,33 +479,35 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( partial(decoder_layer.__call__, **flash_attn_kwargs), hidden_states, - attention_mask, + causal_mask, head_mask[idx] if head_mask is not None else None, None, output_attentions, use_cache, + position_ids, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + position_ids=position_ids, + cache_position=cache_position, **flash_attn_kwargs, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -356,6 +519,8 @@ def forward( hidden_states = self.layer_norm(hidden_states) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( @@ -371,51 +536,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask with attention_dropout->attention_probs_dropout_prob - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_probs_dropout_prob == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ... @@ -453,9 +573,11 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs: Unpack[KwargsForCausalLM], ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -473,9 +595,11 @@ def forward( inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, **kwargs, ) @@ -541,9 +665,11 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, TokenClassifierOutput]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -560,9 +686,11 @@ def forward( head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] @@ -629,9 +757,11 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, SequenceClassifierOutputWithPast]: r""" labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): @@ -648,9 +778,11 @@ def forward( head_mask=head_mask, inputs_embeds=inputs_embeds, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = transformer_outputs[0] logits = self.score(hidden_states) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 7c293ab1b290..57b6334d3720 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -26,8 +26,10 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, @@ -43,14 +45,13 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel from .configuration_blenderbot import BlenderbotConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask - + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -80,13 +81,16 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__(num_embeddings, embedding_dim) - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot @@ -143,7 +147,8 @@ def __init__( is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - config: Optional[BlenderbotConfig] = None, + config: Optional[BartConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -160,6 +165,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -170,10 +182,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -194,42 +207,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 0ddafe75287c..31eb58ab1feb 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -25,8 +25,10 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask, @@ -42,12 +44,12 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_marian import MarianConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -142,6 +144,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[MarianConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -158,6 +161,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -168,10 +178,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -192,42 +203,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -268,7 +274,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN class MarianEncoderLayer(nn.Module): - def __init__(self, config: MarianConfig): + def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -277,6 +283,7 @@ def __init__(self, config: MarianConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -339,7 +346,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN class MarianDecoderLayer(nn.Module): - def __init__(self, config: MarianConfig): + def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -350,6 +357,7 @@ def __init__(self, config: MarianConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -362,6 +370,7 @@ def __init__(self, config: MarianConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -376,9 +385,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -397,16 +407,16 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -416,28 +426,22 @@ def forward( hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -453,7 +457,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -466,6 +470,8 @@ class MarianPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): std = self.config.init_std @@ -479,6 +485,9 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -491,6 +500,212 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class MarianEncoder(MarianPreTrainedModel): """ @@ -521,7 +736,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = self.embed_positions = MarianSinusoidalPositionalEmbedding( config.max_position_embeddings, embed_dim, self.padding_idx ) - self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([MarianEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -663,34 +878,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class MarianDecoder(MarianPreTrainedModel): """ @@ -717,7 +904,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = self.embed_positions = MarianSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx ) - self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([MarianDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -743,6 +930,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -808,6 +996,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -816,8 +1007,15 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input_shape = input_ids.size() @@ -827,19 +1025,47 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + # TODO: update mask creation with new interface _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) attention_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, + cache_position, + self_attn_cache, _unsupported_features, + self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, @@ -847,6 +1073,7 @@ def forward( input_shape, inputs_embeds, _unsupported_features, + self.config.attention_dropout, ) # embed positions @@ -856,18 +1083,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -885,8 +1105,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -899,6 +1117,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -910,14 +1129,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -930,6 +1150,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -944,92 +1167,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class MarianModel(MarianPreTrainedModel): @@ -1134,6 +1271,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Seq2SeqModelOutput: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1216,6 +1354,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1390,6 +1529,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Seq2SeqLMOutput: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1464,6 +1604,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1568,6 +1709,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1617,6 +1759,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index e9deda6d7488..5b5e0ce6ae92 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -30,10 +30,10 @@ AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) -from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_flash_attention_utils import ( + FlashAttentionKwargs, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -45,7 +45,12 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_mbart import MBartConfig @@ -286,7 +291,7 @@ def forward( class MBartEncoderLayer(nn.Module): - def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None): + def __init__(self, config: MBartConfig): super().__init__() self.embed_dim = config.d_model @@ -295,7 +300,6 @@ def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, - layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -448,7 +452,6 @@ def forward( layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, - cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -742,7 +745,6 @@ def _update_cross_attn_mask( return encoder_attention_mask - class MBartEncoder(MBartPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -775,7 +777,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList([MBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) + self.layers = nn.ModuleList([MBartEncoderLayer(config) for _ in range(config.encoder_layers)]) self.config = config self.layernorm_embedding = nn.LayerNorm(embed_dim) self.layer_norm = nn.LayerNorm(config.d_model) @@ -953,7 +955,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([MBartDecoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) + self.layers = nn.ModuleList([MBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.config = config self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -983,7 +985,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1060,19 +1062,12 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # retrieve input_ids and inputs_embeds if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input = input_ids - input_shape = input.size() + input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] @@ -1081,7 +1076,14 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False # initialize `past_key_values` return_legacy_cache = False @@ -1113,7 +1115,7 @@ def forward( ) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - attention_mask = self._update_causal_mask( + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, @@ -1131,9 +1133,9 @@ def forward( ) # embed positions - positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + hidden_states = inputs_embeds + position_ids.to(inputs_embeds.device) hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1165,7 +1167,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1178,7 +1180,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1280,7 +1282,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqModelOutput, Tuple[torch.FloatTensor]]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1437,7 +1439,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 452c89ff969d..e46519649f59 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -499,7 +499,7 @@ def eager_attn_forward( return attn_output, attn_weights -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->NllbMoe,key_value_states->encoder_hidden_states +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states class NllbMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index ee243a6c79a7..a8d82fac19bb 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -265,7 +265,7 @@ def eager_attn_forward( return attn_output, attn_weights -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTSMixer class PatchTSMixerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 807c56e0833d..f302381f776e 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -62,7 +62,7 @@ def eager_attn_forward( return attn_output, attn_weights -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTST class PatchTSTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 96f90eeab3ce..ed80516e3a8c 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -28,12 +28,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -46,12 +46,12 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging from .configuration_plbart import PLBartConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -91,6 +91,216 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class PLBartLearnedPositionalEmbedding(nn.Embedding): """ @@ -103,15 +313,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) def eager_attn_forward( @@ -153,6 +366,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PLBartConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -169,6 +383,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -179,10 +400,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -203,42 +425,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -278,7 +495,7 @@ def forward( class PLBartEncoderLayer(nn.Module): - def __init__(self, config: PLBartConfig): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -287,6 +504,7 @@ def __init__(self, config: PLBartConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -379,7 +597,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) self.layernorm_embedding = nn.LayerNorm(embed_dim) self.gradient_checkpointing = False @@ -526,36 +744,9 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class PLBartDecoderLayer(nn.Module): - def __init__(self, config: PLBartConfig): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -566,6 +757,7 @@ def __init__(self, config: PLBartConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -578,6 +770,7 @@ def __init__(self, config: PLBartConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -592,9 +785,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -613,47 +807,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -669,7 +859,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -702,7 +892,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -730,6 +920,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -795,6 +986,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -803,8 +997,15 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: input = input_ids @@ -816,19 +1017,47 @@ def forward( else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: inputs_embeds = self.embed_tokens(input) + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + # TODO: update mask creation with new interface _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) attention_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, + cache_position, + self_attn_cache, _unsupported_features, + self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, @@ -836,10 +1065,11 @@ def forward( input_shape, inputs_embeds, _unsupported_features, + self.config.attention_dropout, ) # embed positions - positions = self.embed_positions(input, past_key_values_length) + positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) positions = positions.to(inputs_embeds.device) hidden_states = inputs_embeds + positions @@ -847,18 +1077,11 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -878,8 +1101,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -892,6 +1113,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -903,14 +1125,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -923,6 +1146,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -937,90 +1163,6 @@ def forward( cross_attentions=all_cross_attentions, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int): """ @@ -1095,6 +1237,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1169,6 +1312,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1251,6 +1395,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1328,6 +1473,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) @@ -1432,6 +1578,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1489,6 +1636,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] # last hidden state @@ -1609,6 +1757,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1658,6 +1807,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 0d0b9fd644da..1bc4326c77ec 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -22,14 +22,20 @@ from torch import nn from torch.nn import CrossEntropyLoss +from ...cache_utils import Cache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, Seq2SeqLMOutput, Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring +from ...utils import auto_docstring, is_torch_flex_attn_available from ..bart.modeling_bart import ( BartClassificationHead, BartDecoder, @@ -42,6 +48,10 @@ from .configuration_plbart import PLBartConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + + class PLBartScaledWordEmbedding(BartScaledWordEmbedding): pass @@ -67,6 +77,213 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class PLBartEncoder(BartEncoder): pass @@ -129,6 +346,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -203,6 +421,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -285,6 +504,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -362,6 +582,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py index 1f8b236e8760..b3aa3e01b6cd 100644 --- a/src/transformers/models/sew/modular_sew.py +++ b/src/transformers/models/sew/modular_sew.py @@ -28,8 +28,8 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring -from ..bart.modeling_bart import BartAttention from ..wav2vec2.modeling_wav2vec2 import ( + Wav2Vec2Attention, Wav2Vec2EncoderLayer, Wav2Vec2FeatureEncoder, Wav2Vec2FeedForward, @@ -143,7 +143,7 @@ def __init__(self, config): ) -class SEWAttention(BartAttention): +class SEWAttention(Wav2Vec2Attention): pass diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 5a90308198ef..3b173c35cd44 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -201,7 +201,7 @@ def eager_attn_forward( return attn_output, attn_weights -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Speech2Text +# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text class Speech2TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -407,7 +407,8 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT +# copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT +# TODO: change copy when applying cache class class Speech2TextDecoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() @@ -438,6 +439,7 @@ def __init__(self, config: Speech2TextConfig): self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index bd673d796a62..f0dccff08110 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2149,6 +2149,7 @@ def test_generate_compile_model_forward(self): "return_dict_in_generate": True, "output_scores": True, "compile_config": compile_config, + "use_cache": True, } # 4. get eager + dynamic cache results for future comparison diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index ecdde66786cf..34e633757acd 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -611,7 +611,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -767,7 +767,7 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index ec013467432a..ed42b1b29f00 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -99,7 +99,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=100, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -653,7 +653,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=100, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -796,9 +796,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 521f27319df7..4ef22c3c30e0 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -538,7 +538,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -681,9 +681,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 303e5b734531..179750582832 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -496,7 +496,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -634,9 +634,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 0610342ba08c..429204551be8 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4506,8 +4506,15 @@ def test_flex_attention_with_grads(self): # TODO: raushan, fix for composite models after making VLMs support new attn API if not model_class._supports_flex_attn or self._is_composite: self.skipTest(reason="This model does not support flex attention") + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config._attn_implementation = "flex_attention" + # Flex Attention can not use dropout + if hasattr(config, "attention_droput"): + config.attention_droput = 0 + if hasattr(config, "attention_probs_dropout_prob"): + config.attention_probs_dropout_prob = 0 + model = model_class(config).to(device=torch_device) self.assertTrue(model.config._attn_implementation == "flex_attention") @@ -4519,10 +4526,6 @@ def test_flex_attention_with_grads(self): else: dummy_input = {model_class.main_input_name: inputs_dict[model_class.main_input_name].to(torch_device)} - # Flex Attention can not use dropout - if hasattr(config, "attention_droput"): - config.attention_droput = 0 - # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) _ = model(**dummy_input) From 66c93c146481b0f50f2b62792c7b738f48bfe0e1 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 19 May 2025 16:38:40 +0200 Subject: [PATCH 48/68] marian works now --- .../models/marian/modeling_marian.py | 67 ++++++++++++------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 31eb58ab1feb..a2d9138ab34b 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -31,8 +31,6 @@ AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -44,12 +42,19 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_marian import MarianConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -94,13 +99,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.eager_attn_forward @@ -420,6 +428,7 @@ def forward( attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -437,6 +446,7 @@ def forward( layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -736,7 +746,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = self.embed_positions = MarianSinusoidalPositionalEmbedding( config.max_position_embeddings, embed_dim, self.padding_idx ) - self.layers = nn.ModuleList([MarianEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) + self.layers = nn.ModuleList([MarianEncoderLayer(config) for _ in range(config.encoder_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -930,7 +940,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1018,15 +1028,20 @@ def forward( if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: - input_shape = input_ids.size() + input = input_ids + input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input) + + # Important to apply outside of the above `if`, in case user passes `embeds` + inputs_embeds = inputs_embeds * self.embed_scale # initialize `past_key_values` return_legacy_cache = False @@ -1046,9 +1061,6 @@ def forward( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # TODO: update mask creation with new interface - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past cache mask_seq_length = past_key_values_length + seq_length @@ -1059,7 +1071,10 @@ def forward( if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) - attention_mask = self._update_causal_mask( + + # TODO: update mask creation with new interface + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None + causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, @@ -1077,10 +1092,10 @@ def forward( ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - - hidden_states = inputs_embeds + positions - + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) # decoder layers @@ -1109,7 +1124,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1122,7 +1137,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), @@ -1271,7 +1286,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Seq2SeqModelOutput: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1529,7 +1544,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Seq2SeqLMOutput: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): From f8368cf14a10a7a52b19a5a8fa9b4eb18476d014 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 19 May 2025 16:54:09 +0200 Subject: [PATCH 49/68] fixing some copies --- .../models/autoformer/modeling_autoformer.py | 15 +++++++++------ src/transformers/models/bart/modeling_bart.py | 7 +++---- .../bigbird_pegasus/modeling_bigbird_pegasus.py | 8 +++++--- .../models/musicgen/modeling_musicgen.py | 1 - src/transformers/models/mvp/modeling_mvp.py | 17 ++++++++++------- .../models/nllb_moe/modeling_nllb_moe.py | 6 +++--- .../models/plbart/modeling_plbart.py | 6 +++--- .../models/roformer/modeling_roformer.py | 15 +++++++++------ .../seamless_m4t/modeling_seamless_m4t.py | 8 ++++++++ .../seamless_m4t_v2/modeling_seamless_m4t_v2.py | 8 ++++++++ .../speech_to_text/modeling_speech_to_text.py | 6 +++--- src/transformers/models/trocr/modeling_trocr.py | 15 +++++++++------ src/transformers/models/xglm/modeling_xglm.py | 2 +- 13 files changed, 71 insertions(+), 43 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index ed6cf2d5473b..d59ef544fde1 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -377,13 +377,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Autoformer diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index cbdfaf11ea9c..e6c5fe6ae1ae 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -695,7 +695,6 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], @@ -1108,9 +1107,6 @@ def forward( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # TODO: update mask creation with new interface - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past cache mask_seq_length = past_key_values_length + seq_length @@ -1121,6 +1117,9 @@ def forward( if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + + # TODO: update mask creation with new interface + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index f21506121a46..f33fe73ccfce 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1472,10 +1472,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, - cache_position: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -1494,6 +1494,9 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) @@ -1523,7 +1526,6 @@ def forward( layer_head_mask=cross_attn_layer_head_mask, past_key_value=past_key_value, output_attentions=output_attentions, - cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 32cfd5a7dc31..19a4001da8bc 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -737,7 +737,6 @@ def _update_causal_mask( return attention_mask - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 1c3862349681..1f66bccb9a13 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -61,7 +61,7 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start # Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MVP -class MvpLearnedPositionalEmbedding(nn.Embedding): +class MVPLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. """ @@ -72,15 +72,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) class MvpAttention(nn.Module): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index e46519649f59..87fa1328f540 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -1073,7 +1073,7 @@ def forward( router_probs=all_router_probs, ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -1392,7 +1392,7 @@ def forward( router_probs=all_router_probs, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, None], @@ -1437,7 +1437,7 @@ def _update_causal_mask( return attention_mask - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index ed80516e3a8c..a8bb8d1c5f7d 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1038,9 +1038,6 @@ def forward( past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # TODO: update mask creation with new interface - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - if attention_mask is None and not is_torchdynamo_compiling(): # required mask seq length can be calculated via length of past cache mask_seq_length = past_key_values_length + seq_length @@ -1051,6 +1048,9 @@ def forward( if isinstance(past_key_values, EncoderDecoderCache) else past_key_values ) + + # TODO: update mask creation with new interface + _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 7e4f1ff9f28b..1fdccf728463 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -67,13 +67,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) def load_tf_weights_in_roformer(model, config, tf_checkpoint_path): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index d2feaafa8581..88ffc871b2c8 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -1003,6 +1003,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[SeamlessM4TConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -1019,6 +1020,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 69dc83296731..c62b97fb89ce 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -909,6 +909,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[SeamlessM4Tv2Config] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -925,6 +926,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 3b173c35cd44..010bdf165420 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -737,7 +737,7 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -1020,7 +1020,7 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask def _update_causal_mask( self, attention_mask: Union[torch.Tensor, None], @@ -1065,7 +1065,7 @@ def _update_causal_mask( return attention_mask - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index e2c77f577991..97638e929edf 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -46,15 +46,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 1b9c14ddcaf1..e9f3f63c965a 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -282,7 +282,7 @@ def __init__(self, config: XGLMConfig): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, From f34d11db36274c7389555a2e4be55ce5be8a7613 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Mon, 19 May 2025 18:06:03 +0200 Subject: [PATCH 50/68] some copy fixes + time series x informer --- .../models/autoformer/modeling_autoformer.py | 56 +-- .../modeling_bigbird_pegasus.py | 4 +- .../models/biogpt/modeling_biogpt.py | 2 +- .../models/biogpt/modular_biogpt.py | 3 +- .../models/informer/modeling_informer.py | 456 ++++++++++-------- .../models/informer/modular_informer.py | 202 ++++++-- .../models/m2m_100/modeling_m2m_100.py | 8 +- .../models/marian/modeling_marian.py | 4 +- .../models/mbart/modeling_mbart.py | 6 +- .../models/plbart/modeling_plbart.py | 6 +- .../models/plbart/modular_plbart.py | 6 +- .../modeling_time_series_transformer.py | 400 ++++++++------- 12 files changed, 672 insertions(+), 481 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index d59ef544fde1..afafd5dd3615 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -911,6 +911,34 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoder with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer class AutoformerEncoder(AutoformerPreTrainedModel): @@ -1051,34 +1079,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class AutoformerDecoder(AutoformerPreTrainedModel): """ diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index f33fe73ccfce..fab94c336a23 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1610,7 +1610,7 @@ def dummy_inputs(self): } return dummy_inputs - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], @@ -1746,7 +1746,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index c8ef23b0ef2b..e5931ec84fac 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -380,7 +380,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 8e1c14703fff..8e4567bfcec3 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -42,7 +42,6 @@ LossKwargs, auto_docstring, is_torch_flex_attn_available, - is_torchdynamo_compiling, logger, ) from ..bart.modeling_bart import ( @@ -197,7 +196,7 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 44c9a758f120..804f4df509a9 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -26,6 +26,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -226,13 +227,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) class InformerValueEmbedding(nn.Module): @@ -264,6 +268,120 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + def eager_attn_forward( module: nn.Module, @@ -304,6 +422,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[InformerConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -320,6 +439,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -330,10 +456,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -354,42 +481,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -441,6 +563,7 @@ def __init__( is_decoder: bool = False, sampling_factor: int = 5, bias: bool = True, + layer_idx: Optional[int] = None, ): super().__init__() self.factor = sampling_factor @@ -456,6 +579,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -473,6 +597,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -481,45 +606,43 @@ def forward( is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -749,7 +872,7 @@ def forward( class InformerDecoderLayer(nn.Module): - def __init__(self, config: InformerConfig): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model self.dropout = config.dropout @@ -763,6 +886,7 @@ def __init__(self, config: InformerConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -776,6 +900,7 @@ def __init__(self, config: InformerConfig): dropout=config.attention_dropout, sampling_factor=config.sampling_factor, is_decoder=True, + layer_idx=layer_idx, ) else: self.self_attn = InformerAttention( @@ -784,6 +909,7 @@ def __init__(self, config: InformerConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) def forward( @@ -794,9 +920,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -815,47 +942,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -871,7 +994,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1025,33 +1148,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class InformerDecoder(InformerPreTrainedModel): """ @@ -1073,7 +1169,7 @@ def __init__(self, config: InformerConfig): self.embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([InformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1093,6 +1189,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1150,6 +1247,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1159,9 +1259,22 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = inputs_embeds.size()[:-1] + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( @@ -1195,7 +1308,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1215,8 +1328,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1229,6 +1340,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1240,14 +1352,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1260,6 +1373,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1274,90 +1390,6 @@ def forward( cross_attentions=all_cross_attentions, ) - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class InformerModel(InformerPreTrainedModel): @@ -1518,6 +1550,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1696,6 +1729,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1810,6 +1844,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1981,6 +2016,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) prediction_loss = None diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index 88ce69085d2c..cec23ba220ed 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -20,7 +20,13 @@ import torch from torch import nn -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask +from ...cache_utils import EncoderDecoderCache +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_attention_mask_for_sdpa, + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) from ...modeling_outputs import ( BaseModelOutput, ) @@ -28,6 +34,7 @@ from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import ( auto_docstring, + is_torch_flex_attn_available, ) from ..bart.modeling_bart import BartAttention from ..time_series_transformer.modeling_time_series_transformer import ( @@ -47,6 +54,10 @@ from .configuration_informer import InformerConfig +if is_torch_flex_attn_available(): + from ...integrations.flex_attention import make_flex_block_causal_mask + + def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor: """ Computes the negative log likelihood loss from input distribution with respect to target. @@ -98,6 +109,120 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class InformerAttention(BartAttention): pass @@ -116,6 +241,7 @@ def __init__( is_decoder: bool = False, sampling_factor: int = 5, bias: bool = True, + layer_idx: Optional[int] = None, ): super().__init__() self.factor = sampling_factor @@ -131,6 +257,7 @@ def __init__( ) self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder + self.layer_idx = layer_idx self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -148,6 +275,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" @@ -156,45 +284,43 @@ def forward( is_cross_attention = key_value_states is not None bsz, tgt_len, _ = hidden_states.size() + src_len = key_value_states.shape[1] if is_cross_attention else tgt_len + kv_input_shape = (bsz, src_len, -1, self.head_dim) # get query proj query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) @@ -368,7 +494,7 @@ def __init__(self, config: InformerConfig): class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer): - def __init__(self, config: InformerConfig): + def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None): super().__init__(config) del self.self_attn @@ -380,6 +506,7 @@ def __init__(self, config: InformerConfig): dropout=config.attention_dropout, sampling_factor=config.sampling_factor, is_decoder=True, + layer_idx=layer_idx, ) else: self.self_attn = InformerAttention( @@ -388,6 +515,7 @@ def __init__(self, config: InformerConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) @@ -546,7 +674,7 @@ def __init__(self, config: InformerConfig): self.embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([InformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 1c495c14c6c7..c0d24c8f7dbb 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -30,8 +30,6 @@ AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import ( FlashAttentionKwargs, @@ -562,7 +560,7 @@ def _init_weights(self, module): module.weight.data.fill_(1.0) module.bias.data.zero_() - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -591,7 +589,7 @@ def _update_full_mask( return attention_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], @@ -727,7 +725,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index a2d9138ab34b..1f13fc55e3ce 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -510,7 +510,7 @@ def dummy_inputs(self): } return dummy_inputs - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -538,7 +538,7 @@ def _update_full_mask( return attention_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 5b5e0ce6ae92..cb73384d59cb 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -537,7 +537,7 @@ def dummy_inputs(self): } return dummy_inputs - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -566,7 +566,7 @@ def _update_full_mask( return attention_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], @@ -702,7 +702,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index a8bb8d1c5f7d..e67b84a79c85 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -91,7 +91,7 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -119,7 +119,7 @@ def _update_full_mask( return attention_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], @@ -259,7 +259,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 1bc4326c77ec..aee0e2f4ffd1 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -77,7 +77,7 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], @@ -106,7 +106,7 @@ def _update_full_mask( return attention_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], @@ -242,7 +242,7 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index cc3e99c72dd0..5358ae7cb4c9 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import ( _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, @@ -39,11 +40,7 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import ( - auto_docstring, - is_torch_flex_attn_available, - logging, -) +from ...utils import auto_docstring, is_torch_flex_attn_available, logging from .configuration_time_series_transformer import TimeSeriesTransformerConfig @@ -258,13 +255,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) class TimeSeriesValueEmbedding(nn.Module): @@ -317,6 +317,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[TimeSeriesTransformerConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -333,6 +334,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -343,10 +351,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -367,42 +376,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -443,7 +447,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerEncoderLayer(nn.Module): - def __init__(self, config: TimeSeriesTransformerConfig): + def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -452,6 +456,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -514,7 +519,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerDecoderLayer(nn.Module): - def __init__(self, config: TimeSeriesTransformerConfig): + def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -525,6 +530,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -537,6 +543,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -551,9 +558,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -572,47 +580,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -628,7 +632,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -658,6 +662,120 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + past_key_values_length: int, + _unsupported_features: bool, + ): + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( + attention_mask, + input_shape, + inputs_embeds, + past_key_values_length, + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_shape), + device=inputs_embeds.device, + ) + ) + else: + # 4d mask is passed through the layers + attention_mask = _prepare_4d_causal_attention_mask( + attention_mask, input_shape, inputs_embeds, past_key_values_length + ) + + return attention_mask + + # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel): """ @@ -797,34 +915,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class TimeSeriesTransformerDecoder(TimeSeriesTransformerPreTrainedModel): """ @@ -846,7 +936,9 @@ def __init__(self, config: TimeSeriesTransformerConfig): self.embed_positions = TimeSeriesSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([TimeSeriesTransformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [TimeSeriesTransformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -866,6 +958,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -923,6 +1016,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -932,9 +1028,22 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = inputs_embeds.size()[:-1] + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( @@ -968,7 +1077,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -988,8 +1097,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -1002,6 +1109,7 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( @@ -1013,14 +1121,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1033,6 +1142,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1047,92 +1159,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class TimeSeriesTransformerModel(TimeSeriesTransformerPreTrainedModel): @@ -1293,6 +1319,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1471,6 +1498,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1553,6 +1581,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1724,6 +1753,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) prediction_loss = None From b2a29872369636b75a75e125816e14778409675c Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 20 May 2025 11:28:11 +0200 Subject: [PATCH 51/68] last models possibly and fixes on style/copies --- src/transformers/models/bart/modeling_bart.py | 6 +- .../modeling_bigbird_pegasus.py | 11 +- .../models/biogpt/modular_biogpt.py | 6 +- .../models/blenderbot/modeling_blenderbot.py | 470 +++++++++----- .../modeling_blenderbot_small.py | 610 +++++++++++------- .../models/m2m_100/modeling_m2m_100.py | 7 +- .../models/marian/modeling_marian.py | 6 +- .../models/mbart/modeling_mbart.py | 7 +- src/transformers/models/mvp/modeling_mvp.py | 6 +- .../models/pegasus/modeling_pegasus.py | 559 ++++++++++------ .../models/pegasus_x/modeling_pegasus_x.py | 533 ++++++++++----- .../models/plbart/modular_plbart.py | 7 +- .../blenderbot/test_modeling_blenderbot.py | 8 +- .../test_modeling_blenderbot_small.py | 8 +- tests/models/pegasus/test_modeling_pegasus.py | 6 +- .../pegasus_x/test_modeling_pegasus_x.py | 6 +- 16 files changed, 1440 insertions(+), 816 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index e6c5fe6ae1ae..90988f56b5d8 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -569,7 +569,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index fab94c336a23..081f0da0fbc6 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1620,7 +1620,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -1789,7 +1793,6 @@ def _update_cross_attn_mask( return encoder_attention_mask - class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a @@ -2118,7 +2121,9 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BigBirdPegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 8e4567bfcec3..64528fcfc34d 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -206,7 +206,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 57b6334d3720..ebd37b964da0 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -32,8 +32,6 @@ AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -45,7 +43,12 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel from .configuration_blenderbot import BlenderbotConfig @@ -53,6 +56,7 @@ if is_torch_flex_attn_available(): from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask + logger = logging.get_logger(__name__) @@ -147,7 +151,7 @@ def __init__( is_decoder: bool = False, bias: bool = True, is_causal: bool = False, - config: Optional[BartConfig] = None, + config: Optional[BlenderbotConfig] = None, layer_idx: Optional[int] = None, ): super().__init__() @@ -347,7 +351,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT class BlenderbotDecoderLayer(nn.Module): - def __init__(self, config: BlenderbotConfig): + def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -358,6 +362,7 @@ def __init__(self, config: BlenderbotConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -370,6 +375,7 @@ def __init__(self, config: BlenderbotConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -384,9 +390,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -405,47 +412,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -461,7 +463,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -474,6 +476,8 @@ class BlenderbotPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -485,6 +489,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -497,6 +504,216 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BlenderbotEncoder(BlenderbotPreTrainedModel): """ @@ -672,34 +889,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class BlenderbotDecoder(BlenderbotPreTrainedModel): """ @@ -729,7 +918,9 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BlenderbotDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -756,6 +947,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position: Optional[torch.Tensor] = None, ): r""" Args: @@ -822,6 +1014,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -831,29 +1026,65 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: - input_shape = input_ids.size() + input = input_ids + input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - attention_mask = self._update_causal_mask( + causal_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, + cache_position, + self_attn_cache, _unsupported_features, + self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, @@ -861,26 +1092,22 @@ def forward( input_shape, inputs_embeds, _unsupported_features, + self.config.attention_dropout, ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - - hidden_states = inputs_embeds + positions + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -899,13 +1126,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -913,25 +1138,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -947,6 +1174,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -961,92 +1191,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class BlenderbotModel(BlenderbotPreTrainedModel): @@ -1109,6 +1253,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1187,6 +1332,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1285,6 +1431,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1371,6 +1518,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1472,6 +1620,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1521,6 +1670,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 345f2b2cfe1f..c99549923f58 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -24,12 +24,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -41,12 +41,17 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_blenderbot_small import BlenderbotSmallConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -78,13 +83,16 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__(num_embeddings, embedding_dim) - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.eager_attn_forward @@ -128,6 +136,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BlenderbotSmallConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -144,6 +153,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -154,10 +170,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -178,42 +195,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -254,7 +266,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallEncoderLayer(nn.Module): - def __init__(self, config: BlenderbotSmallConfig): + def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -263,6 +275,7 @@ def __init__(self, config: BlenderbotSmallConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -325,7 +338,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): - def __init__(self, config: BlenderbotSmallConfig): + def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -336,6 +349,7 @@ def __init__(self, config: BlenderbotSmallConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -348,6 +362,7 @@ def __init__(self, config: BlenderbotSmallConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -362,9 +377,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -383,47 +399,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -439,7 +451,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -452,6 +464,8 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -463,6 +477,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -475,6 +492,216 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask + class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): """ @@ -646,34 +873,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel): """ @@ -701,7 +900,9 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BlenderbotSmallDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BlenderbotSmallDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -728,6 +929,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -793,6 +995,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -802,29 +1007,65 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: - input_shape = input_ids.size() + input = input_ids + input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - attention_mask = self._update_causal_mask( + causal_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, + cache_position, + self_attn_cache, _unsupported_features, + self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, @@ -832,29 +1073,24 @@ def forward( input_shape, inputs_embeds, _unsupported_features, + self.config.attention_dropout, ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) # BlenderbotSmall applies layer norm on hidden_states inputs_embeds = self.layernorm_embedding(inputs_embeds) - hidden_states = inputs_embeds + positions - + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -873,13 +1109,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -887,25 +1121,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -918,6 +1154,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -932,137 +1171,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel): @@ -1112,6 +1220,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1190,6 +1299,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1273,6 +1383,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1359,6 +1470,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1460,6 +1572,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1509,6 +1622,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index c0d24c8f7dbb..7889289ba8ea 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -588,7 +588,6 @@ def _update_full_mask( return attention_mask - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, @@ -599,7 +598,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 1f13fc55e3ce..1e5635bd6b6e 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -548,7 +548,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index cb73384d59cb..5479089ee3f7 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -565,7 +565,6 @@ def _update_full_mask( return attention_mask - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, @@ -576,7 +575,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 1f66bccb9a13..5d270380f6fa 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -60,14 +60,14 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MVP -class MVPLearnedPositionalEmbedding(nn.Embedding): +# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->Mvp +class MvpLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. """ def __init__(self, num_embeddings: int, embedding_dim: int): - # MVP is set up so that if padding_idx is specified then offset the embedding ids by 2 + # Mvp is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index f755a2b9d15f..ed3e0050598b 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,12 +25,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -42,12 +42,17 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_pegasus import PegasusConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -93,13 +98,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.eager_attn_forward @@ -143,6 +151,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PegasusConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -159,6 +168,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -169,10 +185,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -193,42 +210,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -338,7 +350,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS class PegasusDecoderLayer(nn.Module): - def __init__(self, config: PegasusConfig): + def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -349,6 +361,7 @@ def __init__(self, config: PegasusConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -361,6 +374,7 @@ def __init__(self, config: PegasusConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -375,9 +389,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -396,47 +411,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -452,7 +462,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -465,6 +475,8 @@ class PegasusPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -478,6 +490,219 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask class PegasusEncoder(PegasusPreTrainedModel): @@ -683,34 +908,6 @@ def forward( last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions ) - # Copied from transformers.models.bart.modeling_bart.BartEncoder._update_full_mask - def _update_full_mask( - self, - attention_mask: Union[torch.Tensor, None], - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) - - return attention_mask - class PegasusDecoder(PegasusPreTrainedModel): """ @@ -739,7 +936,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = config.d_model, self.padding_idx, ) - self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([PegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -796,6 +993,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -861,6 +1059,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -870,29 +1071,68 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: - input_shape = input_ids.size() + input = input_ids + input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input) + + # important to apply scale outside of `if` in case users pass `embeds` + inputs_embeds = inputs_embeds * self.embed_scale + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None - attention_mask = self._update_causal_mask( + causal_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, + cache_position, + self_attn_cache, _unsupported_features, + self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, @@ -900,27 +1140,19 @@ def forward( input_shape, inputs_embeds, _unsupported_features, + self.config.attention_dropout, ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - + positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position) hidden_states = inputs_embeds + positions - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -939,13 +1171,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -953,25 +1183,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -986,6 +1218,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1000,92 +1235,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class PegasusModel(PegasusPreTrainedModel): @@ -1158,6 +1307,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1237,6 +1387,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1343,6 +1494,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1418,6 +1570,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1544,6 +1697,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1593,6 +1747,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 8a16e232981e..080718d66734 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,12 +25,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_outputs import ( @@ -41,12 +41,17 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_torch_flex_attn_available, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_pegasus_x import PegasusXConfig if is_torch_flex_attn_available(): - from ...integrations.flex_attention import make_flex_block_causal_mask + from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask logger = logging.get_logger(__name__) @@ -110,20 +115,24 @@ def __init__(self, embed_dim, max_scale: int = 10000.0): self.max_scale = max_scale @torch.no_grad() - def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_embeds: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" batch_size, seq_len = input_embeds.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device - )[:, None] + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device + )[:, None] + pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device, dtype=input_embeds.dtype) half_d_feature = self.embed_dim // 2 div_term = torch.exp( torch.arange(half_d_feature, device=input_embeds.device, dtype=torch.int64).type_as(input_embeds) * -(np.log(float(self.max_scale)) / (half_d_feature - 1)) ) - pe[:, :half_d_feature] = torch.sin(positions * div_term) - pe[:, half_d_feature:] = torch.cos(positions * div_term) + pe[:, :half_d_feature] = torch.sin(position_ids * div_term) + pe[:, half_d_feature:] = torch.cos(position_ids * div_term) return pe[None].expand(batch_size, -1, -1) @@ -168,6 +177,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PegasusXConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -184,6 +194,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -194,10 +211,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, # TODO: we need a refactor so that the different attention modules can get their specific kwargs # ATM, we have mixed things encoder, decoder, and encoder-decoder attn **kwargs: Unpack[FlashAttentionKwargs], @@ -218,42 +236,37 @@ def forward( # get query proj query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(*kv_input_shape).transpose(1, 2) + value_states = value_states.view(*kv_input_shape).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation @@ -642,7 +655,7 @@ def unpad_local_tokens(cls, padded_hidden_states, block_size): class PegasusXDecoderLayer(nn.Module): - def __init__(self, config: PegasusXConfig): + def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -653,6 +666,7 @@ def __init__(self, config: PegasusXConfig): is_decoder=True, bias=False, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -666,6 +680,7 @@ def __init__(self, config: PegasusXConfig): is_decoder=True, bias=False, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -678,9 +693,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -696,45 +712,40 @@ def forward( Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. use_cache: Whether to us KV cache for decoding + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -750,7 +761,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -765,6 +776,8 @@ class PegasusXPreTrainedModel(PreTrainedModel): # TODO: Flaky logits _supports_sdpa = False _supports_flex_attn = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -774,6 +787,219 @@ def _init_weights(self, module): module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask + def _update_full_mask( + self, + attention_mask: Union[torch.Tensor, None], + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + ): + if attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + attention_mask = attention_mask if 0 in attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & head_mask can not be supported when using SDPA, fall back to + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (self.config.attention_dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype) + + return attention_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Optional[Union[torch.Tensor, "BlockMask"]], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + _unsupported_features: bool = False, + dropout: float = 0.0, + ): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + # Other attention flavors support in-built causal (when `mask is None`) + # while we need to create our specific block mask regardless + elif attention_mask is None: + attention_mask = make_flex_block_causal_mask( + torch.ones( + size=(input_tensor.shape[0], input_tensor.shape[1]), + device=attention_mask.device, + ) + ) + return attention_mask + + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not _unsupported_features + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + def _update_cross_attn_mask( + self, + encoder_hidden_states: Union[torch.Tensor, None], + encoder_attention_mask: Union[torch.Tensor, None], + input_shape: torch.Size, + inputs_embeds: torch.Tensor, + _unsupported_features: bool, + dropout: float = 0.0, + ): + # expand encoder attention mask + if encoder_hidden_states is not None and encoder_attention_mask is not None: + if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None + elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on + # the manual implementation that requires a 4D causal mask in all cases. + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=input_shape[-1], + ) + elif ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): + if isinstance(encoder_attention_mask, torch.Tensor): + encoder_attention_mask = make_flex_block_causal_mask( + encoder_attention_mask, + query_length=input_shape[-1], + is_causal=False, + ) + else: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + encoder_attention_mask = _prepare_4d_attention_mask( + encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + ) + + return encoder_attention_mask class PegasusXEncoder(PegasusXPreTrainedModel): @@ -1013,7 +1239,7 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] ) self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model) - self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([PegasusXDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1038,6 +1264,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -1092,6 +1319,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1101,29 +1331,65 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") elif input_ids is not None: - input_shape = input_ids.size() + input = input_ids + input_shape = input.shape input_ids = input_ids.view(-1, input_shape[-1]) elif inputs_embeds is not None: input_shape = inputs_embeds.size()[:-1] + input = inputs_embeds[:, :, -1] else: raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) _unsupported_features = output_attentions is True - attention_mask = self._update_causal_mask( + causal_mask = self._update_causal_mask( attention_mask, - input_shape, inputs_embeds, - past_key_values_length, + cache_position, + self_attn_cache, _unsupported_features, + self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, @@ -1131,29 +1397,21 @@ def forward( input_shape, inputs_embeds, _unsupported_features, + self.config.attention_dropout, ) # embed positions - positions = self.embed_positions(inputs_embeds, past_key_values_length) - - positions = positions.to(inputs_embeds.device) - - hidden_states = inputs_embeds + positions - + position_ids = cache_position.unsqueeze(1) + position_ids = self.embed_positions(inputs_embeds, past_key_values_length, position_ids) + position_ids = position_ids.to(inputs_embeds.device) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1164,33 +1422,33 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1205,6 +1463,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1219,92 +1480,6 @@ def forward( cross_attentions=all_cross_attentions, ) - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_causal_mask - def _update_causal_mask( - self, - attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - past_key_values_length: int, - _unsupported_features: bool, - ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(attention_mask, torch.Tensor): - attention_mask = make_flex_block_causal_mask(attention_mask) - # Other attention flavors support in-built causal (when `mask is None`) - # while we need to create our specific block mask regardless - elif attention_mask is None: - attention_mask = make_flex_block_causal_mask( - torch.ones( - size=(input_shape), - device=inputs_embeds.device, - ) - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - - return attention_mask - - # Copied from transformers.models.bart.modeling_bart.BartDecoder._update_cross_attn_mask - def _update_cross_attn_mask( - self, - encoder_hidden_states: Union[torch.Tensor, None], - encoder_attention_mask: Union[torch.Tensor, None], - input_shape: torch.Size, - inputs_embeds: torch.Tensor, - _unsupported_features: bool, - ): - # expand encoder attention mask - if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: - encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( - encoder_attention_mask, - inputs_embeds.dtype, - tgt_len=input_shape[-1], - ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): - if isinstance(encoder_attention_mask, torch.Tensor): - encoder_attention_mask = make_flex_block_causal_mask( - encoder_attention_mask, - query_length=input_shape[-1], - is_causal=False, - ) - else: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] - ) - - return encoder_attention_mask - @auto_docstring class PegasusXModel(PegasusXPreTrainedModel): @@ -1378,6 +1553,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1448,6 +1624,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1533,6 +1710,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1578,6 +1756,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index aee0e2f4ffd1..716374a2dc89 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -105,7 +105,6 @@ def _update_full_mask( return attention_mask - # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask def _update_causal_mask( self, @@ -116,7 +115,11 @@ def _update_causal_mask( _unsupported_features: bool = False, dropout: float = 0.0, ): - if self.config._attn_implementation == "flex_attention" and not _unsupported_features and (dropout == 0 or not self.training): + if ( + self.config._attn_implementation == "flex_attention" + and not _unsupported_features + and (dropout == 0 or not self.training) + ): if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index 0c54544ca66f..bec16cf5dc13 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -356,7 +356,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -500,9 +500,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index e7916c6c1b69..8d75649d8cc1 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -366,7 +366,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -509,9 +509,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index 9f0ed926ef28..af119c41d335 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -542,9 +542,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index 5a63cc56181f..a6bf913e4c2e 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -78,7 +78,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=50, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -676,7 +676,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -819,7 +819,7 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() From 917d3c9a5c22ab031a3dbf2439b22153a6f7cd7e Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 20 May 2025 11:50:05 +0200 Subject: [PATCH 52/68] some post merge fixes --- src/transformers/models/bart/modeling_bart.py | 8 ++++---- .../modeling_bigbird_pegasus.py | 4 ++-- .../patchtsmixer/modeling_patchtsmixer.py | 20 +++++++++---------- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 7fbac3ee1014..90988f56b5d8 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -1091,7 +1091,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` return_legacy_cache = False @@ -1145,7 +1145,7 @@ def forward( positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) positions = positions.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1178,7 +1178,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1191,7 +1191,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 6668761bf853..081f0da0fbc6 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2330,7 +2330,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -2343,7 +2343,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 6e344ee5737b..a8d82fac19bb 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -29,7 +29,6 @@ from ...processing_utils import Unpack from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, logging -from ...utils.deprecation import deprecate_kwarg from .configuration_patchtsmixer import PatchTSMixerConfig @@ -279,7 +278,6 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PatchTSMixerConfig] = None, - layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -296,13 +294,6 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal - self.layer_idx = layer_idx - if layer_idx is None and self.is_decoder: - logger.warning_once( - f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " - "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " - "when creating this class." - ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) @@ -364,8 +355,15 @@ def forward( key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2) - key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) - value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + if self.is_decoder: + # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. + # Further calls to cross_attention layer can then reuse all cross-attention + # key/value_states (first "if" case) + # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of + # all previous decoder key/value_states. Further calls to uni-directional self-attention + # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) + # if encoder bi-directional self-attention `past_key_value` is always `None` + past_key_value = (key_states, value_states) attention_interface: Callable = eager_attn_forward attention_type = self.config._attn_implementation From 776e3ca3c82af649ec164d1c3290b7d68e4ff1f3 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 20 May 2025 11:53:44 +0200 Subject: [PATCH 53/68] more fixes --- .../models/biogpt/modeling_biogpt.py | 47 +------------------ .../models/informer/modeling_informer.py | 2 - .../models/plbart/modeling_plbart.py | 8 ++-- src/transformers/models/sew/modeling_sew.py | 1 - tests/generation/test_utils.py | 1 - 5 files changed, 5 insertions(+), 54 deletions(-) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 88ca68ee1f15..e5931ec84fac 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -652,55 +652,10 @@ def forward( ) use_cache = False - # initialize past_key_values - return_legacy_cache = False - if use_cache and not isinstance(past_key_values, Cache): - return_legacy_cache = True - logger.warning_once( - "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " - "You should pass an instance of `EncoderDecoderCache` instead, e.g. " - "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." - ) - past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - - batch_size, seq_length = inputs_embeds.size()[:-1] - past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 - if cache_position is None: - cache_position = torch.arange( - past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device - ) - - if attention_mask is None and not is_torchdynamo_compiling(): - # required mask seq length can be calculated via length of past cache - mask_seq_length = past_key_values_length + seq_length - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - - self_attn_cache = ( - past_key_values.self_attention_cache - if isinstance(past_key_values, EncoderDecoderCache) - else past_key_values - ) - causal_mask = self._update_causal_mask( - attention_mask, - inputs_embeds, - cache_position, - self_attn_cache, - output_attentions, - ) - - # embed positions - if position_ids is None: - position_ids = cache_position.unsqueeze(0) - - position_ids = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) - - hidden_states = inputs_embeds + position_ids - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = None + next_decoder_cache = () if use_cache else None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 36a0db013ee3..804f4df509a9 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1530,7 +1530,6 @@ def get_encoder(self): def get_decoder(self): return self.decoder - # Ignore copy @auto_docstring def forward( self, @@ -1824,7 +1823,6 @@ def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> sliced_params = [p[:, -trailing_n:] for p in params] return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) - # Ignore copy @auto_docstring def forward( self, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 1350611519da..e67b84a79c85 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -1018,7 +1018,7 @@ def forward( raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) + inputs_embeds = self.embed_tokens(input) # initialize `past_key_values` return_legacy_cache = False @@ -1072,7 +1072,7 @@ def forward( positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position) positions = positions.to(inputs_embeds.device) - hidden_states = inputs_embeds + position_ids + hidden_states = inputs_embeds + positions hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) @@ -1105,7 +1105,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - causal_mask, + attention_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1118,7 +1118,7 @@ def forward( else: layer_outputs = decoder_layer( hidden_states, - attention_mask=causal_mask, + attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 5c2fda2832a8..999c9b23c65b 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -36,7 +36,6 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import auto_docstring, logging -from ...utils.deprecation import deprecate_kwarg from .configuration_sew import SEWConfig diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index d867f493ec63..783e7cb3a7a7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2150,7 +2150,6 @@ def test_generate_compile_model_forward(self): "return_dict_in_generate": True, "output_scores": True, "compile_config": compile_config, - "use_cache": True, } # 4. get eager + dynamic cache results for future comparison From a27bfb9ab99f340f7acc4616ee15428098ac0b79 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 20 May 2025 13:30:22 +0200 Subject: [PATCH 54/68] make attention interface callable and move warnings there --- src/transformers/modeling_utils.py | 59 +++++++++++++++++++ src/transformers/models/bart/modeling_bart.py | 39 ++++-------- .../modeling_bigbird_pegasus.py | 37 ++++-------- .../models/biogpt/modeling_biogpt.py | 39 ++++-------- .../models/blenderbot/modeling_blenderbot.py | 37 ++++-------- .../modeling_blenderbot_small.py | 37 ++++-------- .../data2vec/modeling_data2vec_audio.py | 44 ++++---------- .../models/hubert/modeling_hubert.py | 39 ++++-------- .../models/informer/modeling_informer.py | 39 ++++-------- .../models/m2m_100/modeling_m2m_100.py | 37 ++++-------- .../models/marian/modeling_marian.py | 37 ++++-------- .../models/mbart/modeling_mbart.py | 37 ++++-------- .../models/musicgen/modeling_musicgen.py | 37 ++++-------- .../modeling_musicgen_melody.py | 37 ++++-------- .../models/nllb_moe/modeling_nllb_moe.py | 37 ++++-------- .../patchtsmixer/modeling_patchtsmixer.py | 37 ++++-------- .../models/patchtst/modeling_patchtst.py | 37 ++++-------- .../models/pegasus/modeling_pegasus.py | 37 ++++-------- .../models/pegasus_x/modeling_pegasus_x.py | 37 ++++-------- .../models/plbart/modeling_plbart.py | 39 ++++-------- src/transformers/models/sew/modeling_sew.py | 39 ++++-------- .../speech_to_text/modeling_speech_to_text.py | 37 ++++-------- .../modeling_time_series_transformer.py | 37 ++++-------- .../models/unispeech/modeling_unispeech.py | 39 ++++-------- .../unispeech_sat/modeling_unispeech_sat.py | 39 ++++-------- .../models/wav2vec2/modeling_wav2vec2.py | 37 ++++-------- 26 files changed, 344 insertions(+), 663 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index b40b7cd2b388..681f667c0d40 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -164,6 +164,11 @@ if is_kernels_available(): from kernels import get_kernel + +if is_torch_flex_attn_available(): + from .integrations.flex_attention import BlockMask + + logger = logging.get_logger(__name__) @@ -6134,6 +6139,60 @@ class AttentionInterface(MutableMapping): def __init__(self): self._local_mapping = {} + def __call__( + self, + attention_type: str, + eager_attention: Callable, + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Union[torch.Tensor, "BlockMask"], + training: bool = False, + dropout: Optional[float] = 0.0, + scaling: Optional[float] = None, + output_attentions: bool = False, + layer_head_mask: Optional[torch.Tensor] = None, + **kwargs, + ): + attention_interface: Callable = eager_attention + if attention_type != "eager": + if (output_attentions or layer_head_mask is not None) and attention_type in [ + "sdpa", + "flash_attention_2", + "flex_attention", + ]: + logger.warning_once( + f"Falling back to eager attention because `{attention_type}` does not support" + f" `output_attentions=True` or `head_mask`." + ) + elif training and dropout > 0 and attention_type == "flex_attention": + logger.warning_once( + f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." + ) + else: + attention_interface = self[attention_type] + + if scaling is None and attention_type == "eager": + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + + return attention_interface( + module, + query, + key, + value, + attention_mask, + dropout=0.0 if not training else dropout, + scaling=scaling, + layer_head_mask=layer_head_mask, + **kwargs, + ) + def __getitem__(self, key): # First check if instance has a local override if key in self._local_mapping: diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 90988f56b5d8..0e7ce66b4425 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -17,7 +17,7 @@ import copy import math import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -244,33 +244,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 081f0da0fbc6..ed5ba4cdb59a 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1307,33 +1307,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index e5931ec84fac..ad50e2d9b839 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -21,7 +21,7 @@ import math from functools import partial -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn @@ -222,33 +222,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index ebd37b964da0..31299bc242e3 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -243,33 +243,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index c99549923f58..eb4b56d8cda2 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -227,33 +227,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 462fdf56023d..ff2d1f360ef9 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -21,7 +21,7 @@ import math import warnings -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -43,7 +43,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available from .configuration_data2vec_audio import Data2VecAudioConfig @@ -51,9 +51,6 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -logger = logging.get_logger(__name__) - - class Data2VecAudioConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -310,33 +307,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 6048be6b59d4..a30c7bc664a3 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -20,7 +20,7 @@ # limitations under the License. import warnings -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -367,33 +367,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 804f4df509a9..a00b8834c481 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch @@ -513,33 +513,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index c377dfba3424..062c91c58794 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -308,33 +308,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 1e5635bd6b6e..43c350d9a5a3 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -243,33 +243,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 5479089ee3f7..f06c736ea39c 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -253,33 +253,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 19a4001da8bc..3c7525a204bb 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -278,33 +278,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 8e1e3760ad2e..c81e9910e953 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -294,33 +294,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 87fa1328f540..0987b59f024c 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -599,33 +599,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index a8d82fac19bb..5448cb5fa5c8 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -365,33 +365,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index f302381f776e..347865705221 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -162,33 +162,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ed3e0050598b..2538eb5510b1 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -242,33 +242,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 080718d66734..76f8876d42b6 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -268,33 +268,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index e67b84a79c85..89d1dca42339 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -21,7 +21,7 @@ import copy import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -457,33 +457,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 999c9b23c65b..f6de6142b417 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -21,7 +21,7 @@ import math import warnings -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -360,33 +360,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 010bdf165420..a561924a01de 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -301,33 +301,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 5358ae7cb4c9..bd760afb22fe 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -408,33 +408,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index b49b5ffbec49..8302126efe1d 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -22,7 +22,7 @@ import math import warnings from dataclasses import dataclass -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -406,33 +406,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index f41b1bd9c629..60eef3ee80e8 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -22,7 +22,7 @@ import math import warnings from dataclasses import dataclass -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch @@ -409,33 +409,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ae9d03320848..10d849d2365c 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -598,33 +598,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward - attention_type = self.config._attn_implementation - if self.config._attn_implementation != "eager": - if (output_attentions or layer_head_mask is not None) and self.config._attn_implementation in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif self.training and self.dropout > 0 and self.config._attn_implementation == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - - attn_output, attn_weights = attention_interface( - self, - query_states, - key_states, - value_states, - attention_mask, - dropout=0.0 if not self.training else self.dropout, + attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( + attention_type=self.config._attn_implementation, + eager_attention=eager_attn_forward, + module=self, + query=query_states, + key=key_states, + value=value_states, + attention_mask=attention_mask, + training=self.training, + dropout=self.dropout, scaling=self.scaling, + output_attentions=output_attentions, layer_head_mask=layer_head_mask, **kwargs, ) From a066c8566a2ecc42e064355d464b87bd1229662d Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 20 May 2025 13:32:10 +0200 Subject: [PATCH 55/68] style lol --- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 2 +- src/transformers/models/blenderbot/modeling_blenderbot.py | 2 +- .../models/blenderbot_small/modeling_blenderbot_small.py | 2 +- src/transformers/models/m2m_100/modeling_m2m_100.py | 2 +- src/transformers/models/marian/modeling_marian.py | 2 +- src/transformers/models/mbart/modeling_mbart.py | 2 +- src/transformers/models/musicgen/modeling_musicgen.py | 2 +- .../models/musicgen_melody/modeling_musicgen_melody.py | 2 +- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 2 +- src/transformers/models/patchtsmixer/modeling_patchtsmixer.py | 2 +- src/transformers/models/patchtst/modeling_patchtst.py | 2 +- src/transformers/models/pegasus/modeling_pegasus.py | 2 +- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 2 +- .../models/speech_to_text/modeling_speech_to_text.py | 2 +- .../time_series_transformer/modeling_time_series_transformer.py | 2 +- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 2 +- 16 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ed5ba4cdb59a..afacc4910c34 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 31299bc242e3..cec7ab2c798f 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -18,7 +18,7 @@ import math import os import warnings -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index eb4b56d8cda2..8bfb9ad1a3ea 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -16,7 +16,7 @@ import copy import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 062c91c58794..5724cb584399 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -15,7 +15,7 @@ """PyTorch M2M100 model.""" import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 43c350d9a5a3..edaf0567052b 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -16,7 +16,7 @@ import copy import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index f06c736ea39c..cef515b03104 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -16,7 +16,7 @@ import copy import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.utils.checkpoint diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 3c7525a204bb..ebd9de944d31 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -19,7 +19,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index c81e9910e953..ef5111f7faf8 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -19,7 +19,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 0987b59f024c..cb73eecf8cda 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -15,7 +15,7 @@ """PyTorch NLLB-MoE model.""" import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 5448cb5fa5c8..18ddd1760fa6 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch import torch.nn as nn diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 347865705221..456fe3364151 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from torch import nn diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 2538eb5510b1..b4c90a494c10 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 76f8876d42b6..4686c0b3513b 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -16,7 +16,7 @@ import dataclasses import math -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index a561924a01de..b49d5d88a68c 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -15,7 +15,7 @@ """PyTorch Speech2Text model.""" import math -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from torch import nn diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index bd760afb22fe..d79d56e76fd4 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -15,7 +15,7 @@ # limitations under the License. """PyTorch Time Series Transformer model.""" -from typing import Callable, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import numpy as np import torch diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 10d849d2365c..32ee5a45598f 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import numpy as np import torch From ece5b09003cb247ede738536dfa795d4eab08531 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Tue, 20 May 2025 14:15:36 +0200 Subject: [PATCH 56/68] add comment to "unsupported" --- src/transformers/models/autoformer/modeling_autoformer.py | 3 +++ src/transformers/models/bart/modeling_bart.py | 7 ++++++- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 3 +++ src/transformers/models/biogpt/modeling_biogpt.py | 4 ++++ src/transformers/models/biogpt/modular_biogpt.py | 4 ++++ src/transformers/models/blenderbot/modeling_blenderbot.py | 6 ++++++ .../models/blenderbot_small/modeling_blenderbot_small.py | 6 ++++++ .../models/data2vec/modeling_data2vec_audio.py | 3 +++ src/transformers/models/hubert/modeling_hubert.py | 6 ++++++ src/transformers/models/informer/modeling_informer.py | 3 +++ src/transformers/models/m2m_100/modeling_m2m_100.py | 6 ++++++ src/transformers/models/marian/modeling_marian.py | 7 ++++++- src/transformers/models/mbart/modeling_mbart.py | 6 ++++++ src/transformers/models/musicgen/modeling_musicgen.py | 3 +++ .../models/musicgen_melody/modeling_musicgen_melody.py | 3 +++ src/transformers/models/nllb_moe/modeling_nllb_moe.py | 6 ++++++ src/transformers/models/pegasus/modeling_pegasus.py | 6 ++++++ src/transformers/models/pegasus_x/modeling_pegasus_x.py | 3 +++ src/transformers/models/plbart/modeling_plbart.py | 7 ++++++- .../models/speech_to_text/modeling_speech_to_text.py | 6 ++++++ .../modeling_time_series_transformer.py | 6 ++++++ src/transformers/models/unispeech/modeling_unispeech.py | 6 ++++++ .../models/unispeech_sat/modeling_unispeech_sat.py | 6 ++++++ src/transformers/models/wav2vec2/modeling_wav2vec2.py | 6 ++++++ 24 files changed, 119 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index afafd5dd3615..ed83b9eefe9d 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -1018,6 +1018,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 0e7ce66b4425..a0d67b37a262 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -860,6 +860,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1107,7 +1110,9 @@ def forward( else past_key_values ) - # TODO: update mask creation with new interface + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index afacc4910c34..3a421cd5dbfc 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -2262,6 +2262,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index ad50e2d9b839..0d48f82367c9 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -610,12 +610,16 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, output_attentions, + self.config.attention_probs_dropout_prob, ) # embed positions diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 64528fcfc34d..08c1a16218aa 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -441,12 +441,16 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, output_attentions, + self.config.attention_probs_dropout_prob, ) # embed positions diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index cec7ab2c798f..13a03a334e41 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -811,6 +811,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1062,6 +1065,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 8bfb9ad1a3ea..442697e9fce7 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -798,6 +798,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1043,6 +1046,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index ff2d1f360ef9..96ccf06ee791 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -415,6 +415,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index a30c7bc664a3..0c8957e9e5d5 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -475,6 +475,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, @@ -656,6 +659,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index a00b8834c481..f042be7ad33a 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -1261,6 +1261,9 @@ def forward( past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 5724cb584399..f4c71621def4 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -869,6 +869,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1112,6 +1115,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index edaf0567052b..0049e8faa0b9 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -819,6 +819,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1061,7 +1064,9 @@ def forward( else past_key_values ) - # TODO: update mask creation with new interface + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index cef515b03104..edaaa5897768 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -852,6 +852,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1102,6 +1105,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index ebd9de944d31..4a9f859574bf 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -565,6 +565,9 @@ def forward( if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index ef5111f7faf8..568f72b9cc3f 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -549,6 +549,9 @@ def forward( input_shape = inputs_embeds.size()[:-1] + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index cb73eecf8cda..df705764d11e 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -989,6 +989,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1239,6 +1242,9 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index b4c90a494c10..879045ab1e5a 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -831,6 +831,9 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1110,6 +1113,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 4686c0b3513b..6ad777faae9e 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -1367,6 +1367,9 @@ def forward( else past_key_values ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True causal_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 89d1dca42339..a7a27dee5e12 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -668,6 +668,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1034,7 +1037,9 @@ def forward( else past_key_values ) - # TODO: update mask creation with new interface + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index b49d5d88a68c..48786f4b7f7d 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -662,6 +662,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -894,6 +897,9 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index d79d56e76fd4..b9ef51c62189 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -839,6 +839,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, @@ -1030,6 +1033,9 @@ def forward( past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 8302126efe1d..d100bbc7c7d5 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -514,6 +514,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, @@ -695,6 +698,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 60eef3ee80e8..d35d3ae25f8a 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -517,6 +517,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, @@ -698,6 +701,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 32ee5a45598f..e4811c3213d4 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -752,6 +752,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, @@ -863,6 +866,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 + # Efficient attention implementations are not able to interact with certain features, + # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). + # In these cases, we fall back to the eager attention to enable the requested feature(s). _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, From ffdc5660fe97c7a5f3cd759dd4aaa78d96cd0f28 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 21 May 2025 11:02:31 +0200 Subject: [PATCH 57/68] remove callable interface and change interface warnings + some copies --- .../integrations/flash_attention.py | 26 ++++++++- .../integrations/flex_attention.py | 28 ++++++++- .../integrations/sdpa_attention.py | 27 ++++++++- src/transformers/modeling_utils.py | 58 ------------------- src/transformers/models/bart/modeling_bart.py | 39 ++++++++----- .../modeling_bigbird_pegasus.py | 39 ++++++++----- .../models/biogpt/modeling_biogpt.py | 39 ++++++++----- .../models/blenderbot/modeling_blenderbot.py | 39 ++++++++----- .../modeling_blenderbot_small.py | 39 ++++++++----- .../data2vec/modeling_data2vec_audio.py | 44 ++++++++------ .../models/hubert/modeling_hubert.py | 39 ++++++++----- .../models/informer/modeling_informer.py | 39 ++++++++----- .../models/m2m_100/modeling_m2m_100.py | 39 ++++++++----- .../models/marian/modeling_marian.py | 39 ++++++++----- .../models/mbart/modeling_mbart.py | 39 ++++++++----- .../models/musicgen/modeling_musicgen.py | 39 ++++++++----- .../modeling_musicgen_melody.py | 39 ++++++++----- .../models/nllb_moe/modeling_nllb_moe.py | 39 ++++++++----- .../patchtsmixer/modeling_patchtsmixer.py | 39 ++++++++----- .../models/patchtst/modeling_patchtst.py | 39 ++++++++----- .../models/pegasus/modeling_pegasus.py | 39 ++++++++----- .../models/pegasus_x/modeling_pegasus_x.py | 39 ++++++++----- .../models/plbart/modeling_plbart.py | 39 ++++++++----- .../qwen2_5_omni/modeling_qwen2_5_omni.py | 2 - .../qwen2_audio/modeling_qwen2_audio.py | 3 +- src/transformers/models/sew/modeling_sew.py | 39 ++++++++----- .../speech_to_text/modeling_speech_to_text.py | 39 ++++++++----- .../modeling_time_series_transformer.py | 39 ++++++++----- .../models/unispeech/modeling_unispeech.py | 39 ++++++++----- .../unispeech_sat/modeling_unispeech_sat.py | 39 ++++++++----- .../models/wav2vec2/modeling_wav2vec2.py | 39 ++++++++----- 31 files changed, 657 insertions(+), 467 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index a78166ed040b..eecccc6a7dfc 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,10 +1,13 @@ -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask +from ..utils import logging +logger = logging.get_logger(__name__) + _use_top_left_mask = flash_attn_supports_top_left_mask() @@ -18,8 +21,29 @@ def flash_attention_forward( scaling: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, + output_attentions: bool = False, + head_mask: Optional[torch.Tensor] = None, + eager_fallback: Optional[Callable] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + if output_attentions or head_mask is not None: + logger.warning_once( + "Falling back to eager attention because `flash_attention_2` does not support" + " `output_attentions=True` or `head_mask`." + ) + return eager_fallback( + module, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + dropout=dropout, + scaling=scaling, + output_attentions=output_attentions, + head_mask=head_mask, + **kwargs, + ) + # This is before the transpose seq_len = query.shape[2] diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index ea2bc1b12ee5..8fb5c659884a 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -26,12 +26,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from packaging import version -from ..utils import is_torch_flex_attn_available +from ..utils import is_torch_flex_attn_available, logging from ..utils.import_utils import _torch_version @@ -39,6 +39,9 @@ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention +logger = logging.get_logger(__name__) + + class WrappedFlexAttention: """ We are doing a singleton class so that flex attention is compiled once when it's first called. @@ -227,8 +230,29 @@ def flex_attention_forward( scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + dropout: float = 0.0, + eager_fallback: Optional[Callable] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: + if output_attentions or head_mask is not None or dropout > 0: + logger.warning_once( + "Falling back to eager attention because `flex_attention` does not support" + " `output_attentions=True`, `head_mask`, or `dropout`." + ) + return eager_fallback( + module, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + dropout=dropout, + scaling=scaling, + output_attentions=output_attentions, + head_mask=head_mask, + **kwargs, + ) + block_mask = None score_mask = None if isinstance(attention_mask, BlockMask): diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 515eef5ae988..9cb383ceeee9 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,7 +1,12 @@ -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple import torch +from ..utils import logging + + +logger = logging.get_logger(__name__) + def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: """ @@ -24,8 +29,28 @@ def sdpa_attention_forward( dropout: float = 0.0, scaling: Optional[float] = None, is_causal: Optional[bool] = None, + output_attentions: bool = False, + head_mask: Optional[torch.Tensor] = None, + eager_fallback: Optional[Callable] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: + if output_attentions or head_mask is not None: + logger.warning_once( + "Falling back to eager attention because `sdpa` does not support `output_attentions=True` or `head_mask`." + ) + return eager_fallback( + module, + query=query, + key=key, + value=value, + attention_mask=attention_mask, + dropout=dropout, + scaling=scaling, + output_attentions=output_attentions, + head_mask=head_mask, + **kwargs, + ) + if hasattr(module, "num_key_value_groups"): key = repeat_kv(key, module.num_key_value_groups) value = repeat_kv(value, module.num_key_value_groups) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 681f667c0d40..f31db99b246f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -165,10 +165,6 @@ from kernels import get_kernel -if is_torch_flex_attn_available(): - from .integrations.flex_attention import BlockMask - - logger = logging.get_logger(__name__) @@ -6139,60 +6135,6 @@ class AttentionInterface(MutableMapping): def __init__(self): self._local_mapping = {} - def __call__( - self, - attention_type: str, - eager_attention: Callable, - module: nn.Module, - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - attention_mask: Union[torch.Tensor, "BlockMask"], - training: bool = False, - dropout: Optional[float] = 0.0, - scaling: Optional[float] = None, - output_attentions: bool = False, - layer_head_mask: Optional[torch.Tensor] = None, - **kwargs, - ): - attention_interface: Callable = eager_attention - if attention_type != "eager": - if (output_attentions or layer_head_mask is not None) and attention_type in [ - "sdpa", - "flash_attention_2", - "flex_attention", - ]: - logger.warning_once( - f"Falling back to eager attention because `{attention_type}` does not support" - f" `output_attentions=True` or `head_mask`." - ) - elif training and dropout > 0 and attention_type == "flex_attention": - logger.warning_once( - f"Falling back to eager attention because `dropout` is not supported in `{attention_type}`." - ) - else: - attention_interface = self[attention_type] - - if scaling is None and attention_type == "eager": - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) - scaling = query.size(-1) ** -0.5 - - return attention_interface( - module, - query, - key, - value, - attention_mask, - dropout=0.0 if not training else dropout, - scaling=scaling, - layer_head_mask=layer_head_mask, - **kwargs, - ) - def __getitem__(self, key): # First check if instance has a local override if key in self._local_mapping: diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index a0d67b37a262..046c2fc6317e 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -17,7 +17,7 @@ import copy import math import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -120,19 +120,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -244,19 +252,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 3a421cd5dbfc..e0bd7321d8b9 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -1182,19 +1182,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -1307,19 +1315,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 0d48f82367c9..a291ff6a318a 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -21,7 +21,7 @@ import math from functools import partial -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -98,19 +98,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -222,19 +230,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 13a03a334e41..a8f40da4e372 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -18,7 +18,7 @@ import math import os import warnings -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -118,19 +118,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -243,19 +251,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 442697e9fce7..8cbad2653c07 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -102,19 +102,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -227,19 +235,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 96ccf06ee791..53272da2aff9 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -21,7 +21,7 @@ import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -43,7 +43,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging from .configuration_data2vec_audio import Data2VecAudioConfig @@ -51,6 +51,9 @@ from ...integrations.flex_attention import make_flex_block_causal_mask +logger = logging.get_logger(__name__) + + class Data2VecAudioConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -187,19 +190,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -307,19 +318,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 0c8957e9e5d5..85f69c34c7c1 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -20,7 +20,7 @@ # limitations under the License. import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -247,19 +247,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -367,19 +375,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index f042be7ad33a..fdef80d93972 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -19,7 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -389,19 +389,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -513,19 +521,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f4c71621def4..09e062961d9c 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -15,7 +15,7 @@ """PyTorch M2M100 model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -183,19 +183,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -308,19 +316,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 0049e8faa0b9..ae02ee003dd6 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -118,19 +118,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -243,19 +251,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index edaaa5897768..8b053be6672d 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.utils.checkpoint @@ -128,19 +128,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -253,19 +261,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 4a9f859574bf..8207c707831a 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -19,7 +19,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -157,19 +157,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -278,19 +286,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 568f72b9cc3f..80c3b28b0911 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -19,7 +19,7 @@ import math import random from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -173,19 +173,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -294,19 +302,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index df705764d11e..91742d473b21 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -15,7 +15,7 @@ """PyTorch NLLB-MoE model.""" import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -478,19 +478,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -599,19 +607,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 18ddd1760fa6..adb8a8c54de5 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch import torch.nn as nn @@ -244,19 +244,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -365,19 +373,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 456fe3364151..f0c4aa8f6d21 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -16,7 +16,7 @@ import math from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -41,19 +41,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -162,19 +170,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 879045ab1e5a..ba095425169f 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -16,7 +16,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -117,19 +117,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -242,19 +250,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 6ad777faae9e..9d44fc3f9a06 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -16,7 +16,7 @@ import dataclasses import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -143,19 +143,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -268,19 +276,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index a7a27dee5e12..ddd915a980d6 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -21,7 +21,7 @@ import copy import math -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch from torch import nn @@ -333,19 +333,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -457,19 +465,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py index f47ae5c33e61..92be4356e81d 100644 --- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py @@ -738,8 +738,6 @@ def forward( } -# (BC Dep) Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen25OmniAudio, WHISPER->Qwen25OmniAudio -# TODO(vasqu): fix copies when enabling whisper attn interface class Qwen2_5OmniAudioEncoderLayer(nn.Module): def __init__(self, config: Qwen2_5OmniAudioEncoderConfig): super().__init__() diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index c557eb805424..e4208fc8cfd4 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -360,8 +360,7 @@ def forward( } -# (BC Dep) Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO -# TODO(vasqu): fix copies when enabling whisper attn interface +# Copied from transformers.models.whisper.modeling_whisper.WhisperEncoderLayer with Whisper->Qwen2Audio, WHISPER->QWEN2AUDIO class Qwen2AudioEncoderLayer(nn.Module): def __init__(self, config: Qwen2AudioConfig): super().__init__() diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index f6de6142b417..fcaee74cc453 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -21,7 +21,7 @@ import math import warnings -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -240,19 +240,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -360,19 +368,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 48786f4b7f7d..17ab83208729 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -15,7 +15,7 @@ """PyTorch Speech2Text model.""" import math -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import torch from torch import nn @@ -180,19 +180,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -301,19 +309,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index b9ef51c62189..895a6e77c0e4 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -15,7 +15,7 @@ # limitations under the License. """PyTorch Time Series Transformer model.""" -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import numpy as np import torch @@ -283,19 +283,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -408,19 +416,18 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index d100bbc7c7d5..48d7a6d5dae2 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -22,7 +22,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -286,19 +286,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -406,19 +414,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index d35d3ae25f8a..703022fe8c38 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -22,7 +22,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -289,19 +289,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -409,19 +417,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index e4811c3213d4..cda69a5d8d16 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -17,7 +17,7 @@ import math import warnings from dataclasses import dataclass -from typing import Optional, Tuple, Union +from typing import Callable, Optional, Tuple, Union import numpy as np import torch @@ -478,19 +478,27 @@ def eager_attn_forward( key: torch.Tensor, value: torch.Tensor, attention_mask: Optional[torch.Tensor], - scaling: float, + scaling: Optional[float] = None, dropout: float = 0.0, - layer_head_mask: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, **kwargs, ): + if scaling is None: + logger.warning_once( + "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." + " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" + " repo: https://github.com/huggingface/transformers/issues/new" + ) + scaling = query.size(-1) ** -0.5 + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = attn_weights + attention_mask attn_weights = nn.functional.softmax(attn_weights, dim=-1) - if layer_head_mask is not None: - attn_weights = attn_weights * layer_head_mask.view(1, -1, 1, 1) + if head_mask is not None: + attn_weights = attn_weights * head_mask.view(1, -1, 1, 1) attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training) attn_output = torch.matmul(attn_weights, value) @@ -598,19 +606,18 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attn_output, attn_weights = ALL_ATTENTION_FUNCTIONS( - attention_type=self.config._attn_implementation, - eager_attention=eager_attn_forward, - module=self, - query=query_states, - key=key_states, - value=value_states, - attention_mask=attention_mask, - training=self.training, - dropout=self.dropout, + attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.dropout, scaling=self.scaling, output_attentions=output_attentions, - layer_head_mask=layer_head_mask, + head_mask=layer_head_mask, + eager_fallback=eager_attn_forward, **kwargs, ) From 63e38fa5254bb55816a4f5745ef809393a7b08e8 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 21 May 2025 11:13:53 +0200 Subject: [PATCH 58/68] fix --- src/transformers/models/bart/modeling_bart.py | 4 +++- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 4 +++- src/transformers/models/biogpt/modeling_biogpt.py | 4 +++- src/transformers/models/blenderbot/modeling_blenderbot.py | 4 +++- .../models/blenderbot_small/modeling_blenderbot_small.py | 4 +++- src/transformers/models/data2vec/modeling_data2vec_audio.py | 4 +++- src/transformers/models/hubert/modeling_hubert.py | 4 +++- src/transformers/models/informer/modeling_informer.py | 4 +++- src/transformers/models/m2m_100/modeling_m2m_100.py | 4 +++- src/transformers/models/marian/modeling_marian.py | 4 +++- src/transformers/models/mbart/modeling_mbart.py | 4 +++- src/transformers/models/musicgen/modeling_musicgen.py | 4 +++- .../models/musicgen_melody/modeling_musicgen_melody.py | 4 +++- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 4 +++- src/transformers/models/patchtsmixer/modeling_patchtsmixer.py | 4 +++- src/transformers/models/patchtst/modeling_patchtst.py | 4 +++- src/transformers/models/pegasus/modeling_pegasus.py | 4 +++- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 4 +++- src/transformers/models/plbart/modeling_plbart.py | 4 +++- src/transformers/models/sew/modeling_sew.py | 4 +++- .../models/speech_to_text/modeling_speech_to_text.py | 4 +++- .../modeling_time_series_transformer.py | 4 +++- src/transformers/models/unispeech/modeling_unispeech.py | 4 +++- .../models/unispeech_sat/modeling_unispeech_sat.py | 4 +++- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 4 +++- 25 files changed, 75 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 046c2fc6317e..381a293cd8d5 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -252,7 +252,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index e0bd7321d8b9..ed8cd53877db 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1315,7 +1315,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index a291ff6a318a..becf55814a3c 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -230,7 +230,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index a8f40da4e372..201b1b4de57b 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -251,7 +251,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 8cbad2653c07..719e40686260 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -235,7 +235,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 53272da2aff9..66687f816052 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -318,7 +318,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 85f69c34c7c1..3644cd846bc9 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -375,7 +375,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index fdef80d93972..2856ea645a38 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -521,7 +521,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 09e062961d9c..71633a5e19ea 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -316,7 +316,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index ae02ee003dd6..f36bab107c56 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -251,7 +251,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 8b053be6672d..ff9b19a34b02 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -261,7 +261,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 8207c707831a..3b490f8d2f21 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -286,7 +286,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 80c3b28b0911..1f325cb1456c 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -302,7 +302,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 91742d473b21..7569690f6e48 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -607,7 +607,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index adb8a8c54de5..8f47b41fba4c 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -373,7 +373,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index f0c4aa8f6d21..588e34517faa 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -170,7 +170,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index ba095425169f..f55d2e25251e 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -250,7 +250,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 9d44fc3f9a06..b07e9c154150 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -276,7 +276,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index ddd915a980d6..616ac5be7daf 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -465,7 +465,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index fcaee74cc453..d6e95804faed 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -368,7 +368,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 17ab83208729..af5e7e311c7f 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -309,7 +309,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 895a6e77c0e4..1da45dfb5717 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -416,7 +416,9 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 48d7a6d5dae2..7dbae20fed13 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -414,7 +414,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 703022fe8c38..6552407fa1f2 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -417,7 +417,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index cda69a5d8d16..478ea1a30986 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -606,7 +606,9 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attention_interface: Callable = ( + eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + ) attn_output, attn_weights = attention_interface( self, query_states, From c8e10d16c1e37da205188232c3658a1e15857473 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 21 May 2025 11:27:03 +0200 Subject: [PATCH 59/68] ternary is ugly af, make it simpler --- src/transformers/models/bart/modeling_bart.py | 8 +++++--- .../models/bigbird_pegasus/modeling_bigbird_pegasus.py | 7 ++++--- src/transformers/models/biogpt/modeling_biogpt.py | 7 ++++--- src/transformers/models/blenderbot/modeling_blenderbot.py | 7 ++++--- .../models/blenderbot_small/modeling_blenderbot_small.py | 7 ++++--- .../models/data2vec/modeling_data2vec_audio.py | 7 ++++--- src/transformers/models/hubert/modeling_hubert.py | 7 ++++--- src/transformers/models/informer/modeling_informer.py | 7 ++++--- src/transformers/models/m2m_100/modeling_m2m_100.py | 7 ++++--- src/transformers/models/marian/modeling_marian.py | 7 ++++--- src/transformers/models/mbart/modeling_mbart.py | 7 ++++--- src/transformers/models/musicgen/modeling_musicgen.py | 7 ++++--- .../models/musicgen_melody/modeling_musicgen_melody.py | 7 ++++--- src/transformers/models/nllb_moe/modeling_nllb_moe.py | 7 ++++--- .../models/patchtsmixer/modeling_patchtsmixer.py | 7 ++++--- src/transformers/models/patchtst/modeling_patchtst.py | 7 ++++--- src/transformers/models/pegasus/modeling_pegasus.py | 7 ++++--- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 7 ++++--- src/transformers/models/plbart/modeling_plbart.py | 7 ++++--- src/transformers/models/sew/modeling_sew.py | 7 ++++--- .../models/speech_to_text/modeling_speech_to_text.py | 7 ++++--- .../modeling_time_series_transformer.py | 7 ++++--- src/transformers/models/unispeech/modeling_unispeech.py | 7 ++++--- .../models/unispeech_sat/modeling_unispeech_sat.py | 7 ++++--- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 7 ++++--- 25 files changed, 101 insertions(+), 75 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 381a293cd8d5..d95fdf3cd95b 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -252,9 +252,11 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index ed8cd53877db..20ac613db9d5 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1315,9 +1315,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index becf55814a3c..27eddcf716e0 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -230,9 +230,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 201b1b4de57b..567d42eba4ed 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -251,9 +251,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 719e40686260..6b87980698d9 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -235,9 +235,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 66687f816052..b88514b5aecd 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -318,9 +318,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 3644cd846bc9..e1d174c070ee 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -375,9 +375,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 2856ea645a38..72562a548cd4 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -521,9 +521,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 71633a5e19ea..407365e00f46 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -316,9 +316,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index f36bab107c56..1fc7944170c8 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -251,9 +251,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index ff9b19a34b02..5dc0c457bb3b 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -261,9 +261,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 3b490f8d2f21..d5bef01a34c7 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -286,9 +286,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 1f325cb1456c..00f7a48239fe 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -302,9 +302,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 7569690f6e48..d724ca032c67 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -607,9 +607,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 8f47b41fba4c..11beeee5b39d 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -373,9 +373,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 588e34517faa..0f40f1cded14 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -170,9 +170,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index f55d2e25251e..6d2f77513985 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -250,9 +250,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index b07e9c154150..9a230c9e6911 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -276,9 +276,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 616ac5be7daf..96dd97a199ab 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -465,9 +465,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index d6e95804faed..a41dd35807d0 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -368,9 +368,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index af5e7e311c7f..2680d74978ef 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -309,9 +309,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 1da45dfb5717..077b0a7c2f57 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -416,9 +416,10 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 7dbae20fed13..5811e88a56f9 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -414,9 +414,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 6552407fa1f2..2102cae53cd1 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -417,9 +417,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 478ea1a30986..9de46dff4693 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -606,9 +606,10 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = ( - eager_attn_forward if "eager" else ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] - ) + attention_interface: Callable = eager_attn_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + attn_output, attn_weights = attention_interface( self, query_states, From 47425837b4f00de9d5135ce78714da2f8572ec2c Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 21 May 2025 11:29:37 +0200 Subject: [PATCH 60/68] how did that happen --- src/transformers/models/bart/modeling_bart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d95fdf3cd95b..bdeb4013805a 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -252,7 +252,6 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] From b43f3fd95a066f7d7a598a6e4745392f08b7f0c7 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 21 May 2025 13:28:02 +0200 Subject: [PATCH 61/68] fix flex attn test --- tests/test_modeling_common.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6f79bde7649e..5c1a97a7eb7a 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4517,8 +4517,8 @@ def test_flex_attention_with_grads(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config._attn_implementation = "flex_attention" # Flex Attention can not use dropout - if hasattr(config, "attention_droput"): - config.attention_droput = 0 + if hasattr(config, "attention_dropout"): + config.attention_dropout = 0 if hasattr(config, "attention_probs_dropout_prob"): config.attention_probs_dropout_prob = 0 @@ -4526,15 +4526,13 @@ def test_flex_attention_with_grads(self): self.assertTrue(model.config._attn_implementation == "flex_attention") # Elaborate workaround for encoder-decoder models as some do not specify their main input - if "input_ids" in inspect.signature(model.forward).parameters: - dummy_input = {"input_ids": inputs_dict[model_class.main_input_name].to(torch_device)} - if "decoder_input_ids" in inspect.signature(model.forward).parameters: - dummy_input["decoder_input_ids"] = dummy_input["input_ids"].clone() - else: - dummy_input = {model_class.main_input_name: inputs_dict[model_class.main_input_name].to(torch_device)} + dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)} + if config.is_encoder_decoder: + dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"] + dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"] # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605) - _ = model(**dummy_input) + _ = model(**dummy_inputs) def test_generation_tester_mixin_inheritance(self): """ From e8a9139338b2cf32b02c04641b0e81d2972d5864 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Wed, 21 May 2025 14:12:25 +0200 Subject: [PATCH 62/68] failing the test --- src/transformers/models/bart/modeling_bart.py | 3 ++- src/transformers/models/blenderbot/modeling_blenderbot.py | 3 ++- .../models/blenderbot_small/modeling_blenderbot_small.py | 3 ++- src/transformers/models/data2vec/modeling_data2vec_audio.py | 3 ++- src/transformers/models/data2vec/modular_data2vec_audio.py | 3 ++- 5 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index bdeb4013805a..1cd0ec9c25c8 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -500,7 +500,8 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # TODO: compilation issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 567d42eba4ed..970d50f1dc61 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -470,7 +470,8 @@ class BlenderbotPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # TODO: compilation issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 6b87980698d9..936542f12ad3 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -458,7 +458,8 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # TODO: compilation issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index b88514b5aecd..1491234d49c3 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -570,7 +570,8 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # TODO: compilation issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 0b4695c1e28c..51dcc85c63bf 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -140,7 +140,8 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # TODO: compilation issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" From ab0754f7c946a726d83b8561f6c2d120fdf6ba7f Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 22 May 2025 12:21:04 +0200 Subject: [PATCH 63/68] no more fallback! fixing copies next --- src/transformers/models/bart/modeling_bart.py | 59 ++++--------------- tests/generation/test_utils.py | 16 +++++ tests/test_modeling_common.py | 8 ++- 3 files changed, 34 insertions(+), 49 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 1cd0ec9c25c8..825e96677de6 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -114,7 +114,7 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -126,11 +126,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -252,7 +247,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -266,7 +261,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -533,21 +527,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -562,14 +551,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -583,7 +566,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -595,7 +578,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -629,7 +612,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -701,14 +683,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -717,11 +697,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -871,14 +847,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1121,25 +1092,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index 783e7cb3a7a7..57037ee435c7 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -1148,6 +1148,10 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type): # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -1228,6 +1232,10 @@ def test_prompt_lookup_decoding_matches_greedy_search(self): # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") @@ -1282,6 +1290,10 @@ def test_dola_decoding_sample(self): # enable cache if the model is not openai-gpt, xlnet, cpm, or xlm config, inputs_dict = self.prepare_config_and_inputs_for_generate() + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # Encoder-decoder models are not supported if config.is_encoder_decoder: self.skipTest("DoLa is not supported for encoder-decoder models") @@ -1346,6 +1358,10 @@ def test_assisted_decoding_sample(self): # enable cache config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1) + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # NOTE: assisted generation only works with cache on at the moment. if not hasattr(config.get_text_config(), "use_cache"): self.skipTest(reason=f"{model_class.__name__} doesn't support caching") diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 5c1a97a7eb7a..ed98c931e5b3 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -958,6 +958,8 @@ def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.return_dict = True + # force eager attention to support output attentions + config._attn_implementation = "eager" seq_len = getattr(self.model_tester, "seq_length", None) decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len) @@ -1106,7 +1108,7 @@ def _create_and_check_torchscript(self, config, inputs_dict): configs_no_init.torchscript = True for model_class in self.all_model_classes: for attn_implementation in ["eager", "sdpa"]: - if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()): + if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()) or config.output_attentions: continue configs_no_init._attn_implementation = attn_implementation @@ -1723,6 +1725,10 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = self.has_attentions + # force eager attention to support output attentions + if self.has_attentions: + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) From c7c1499772c33c466d021cfa00613315644fc3ab Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 22 May 2025 12:21:44 +0200 Subject: [PATCH 64/68] style + attn fixed --- .../integrations/flash_attention.py | 23 +++------------- .../integrations/flex_attention.py | 27 +++++++------------ .../integrations/sdpa_attention.py | 22 +++------------ tests/test_modeling_common.py | 6 ++++- 4 files changed, 22 insertions(+), 56 deletions(-) diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py index eecccc6a7dfc..4f76e65a847b 100644 --- a/src/transformers/integrations/flash_attention.py +++ b/src/transformers/integrations/flash_attention.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple +from typing import Optional, Tuple import torch @@ -21,27 +21,12 @@ def flash_attention_forward( scaling: Optional[float] = None, sliding_window: Optional[int] = None, softcap: Optional[float] = None, - output_attentions: bool = False, - head_mask: Optional[torch.Tensor] = None, - eager_fallback: Optional[Callable] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: - if output_attentions or head_mask is not None: + if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None: logger.warning_once( - "Falling back to eager attention because `flash_attention_2` does not support" - " `output_attentions=True` or `head_mask`." - ) - return eager_fallback( - module, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - dropout=dropout, - scaling=scaling, - output_attentions=output_attentions, - head_mask=head_mask, - **kwargs, + "`flash_attention_2` does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." ) # This is before the transpose diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py index 8fb5c659884a..b6a41400de48 100644 --- a/src/transformers/integrations/flex_attention.py +++ b/src/transformers/integrations/flex_attention.py @@ -26,7 +26,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Callable, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from packaging import version @@ -230,27 +230,18 @@ def flex_attention_forward( scaling: Optional[float] = None, softcap: Optional[float] = None, head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - dropout: float = 0.0, - eager_fallback: Optional[Callable] = None, **kwargs, ) -> Tuple[torch.Tensor, torch.Tensor]: - if output_attentions or head_mask is not None or dropout > 0: + if kwargs.get("output_attentions", False) or head_mask is not None: logger.warning_once( - "Falling back to eager attention because `flex_attention` does not support" - " `output_attentions=True`, `head_mask`, or `dropout`." + "`flex_attention` does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." ) - return eager_fallback( - module, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - dropout=dropout, - scaling=scaling, - output_attentions=output_attentions, - head_mask=head_mask, - **kwargs, + + if kwargs.get("dropout", 0.0) > 0: + raise ValueError( + "`flex_attention` does not support `dropout`. Please use it with inference" + " only (`model.eval()`) or turn off the attention dropout in the respective config." ) block_mask = None diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py index 9cb383ceeee9..247cd2821679 100644 --- a/src/transformers/integrations/sdpa_attention.py +++ b/src/transformers/integrations/sdpa_attention.py @@ -1,4 +1,4 @@ -from typing import Callable, Optional, Tuple +from typing import Optional, Tuple import torch @@ -29,26 +29,12 @@ def sdpa_attention_forward( dropout: float = 0.0, scaling: Optional[float] = None, is_causal: Optional[bool] = None, - output_attentions: bool = False, - head_mask: Optional[torch.Tensor] = None, - eager_fallback: Optional[Callable] = None, **kwargs, ) -> Tuple[torch.Tensor, None]: - if output_attentions or head_mask is not None: + if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None: logger.warning_once( - "Falling back to eager attention because `sdpa` does not support `output_attentions=True` or `head_mask`." - ) - return eager_fallback( - module, - query=query, - key=key, - value=value, - attention_mask=attention_mask, - dropout=dropout, - scaling=scaling, - output_attentions=output_attentions, - head_mask=head_mask, - **kwargs, + "`sdpa` attention does not support `output_attentions=True` or `head_mask`." + " Please set your attention to `eager` if you want any of these features." ) if hasattr(module, "num_key_value_groups"): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index ed98c931e5b3..7403e95ed61d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1108,7 +1108,11 @@ def _create_and_check_torchscript(self, config, inputs_dict): configs_no_init.torchscript = True for model_class in self.all_model_classes: for attn_implementation in ["eager", "sdpa"]: - if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()) or config.output_attentions: + if ( + attn_implementation == "sdpa" + and (not model_class._supports_sdpa or not is_torch_sdpa_available()) + or config.output_attentions + ): continue configs_no_init._attn_implementation = attn_implementation From e62a8aceeab3c88a0d2fb23f71d979f2aa0b0998 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 22 May 2025 12:49:18 +0200 Subject: [PATCH 65/68] fixing copies and mask creation --- .../models/autoformer/modeling_autoformer.py | 16 +---- .../modeling_bigbird_pegasus.py | 45 +++----------- .../models/biogpt/modeling_biogpt.py | 28 ++------- .../models/biogpt/modular_biogpt.py | 18 +----- .../models/blenderbot/modeling_blenderbot.py | 61 ++++--------------- .../modeling_blenderbot_small.py | 61 ++++--------------- .../data2vec/modeling_data2vec_audio.py | 31 ++-------- .../models/hubert/modeling_hubert.py | 42 +++---------- .../models/informer/modeling_informer.py | 49 ++++----------- .../models/informer/modular_informer.py | 33 +++------- .../models/m2m_100/modeling_m2m_100.py | 61 ++++--------------- .../models/marian/modeling_marian.py | 49 +++------------ .../models/mbart/modeling_mbart.py | 61 ++++--------------- .../models/musicgen/modeling_musicgen.py | 40 +++--------- .../modeling_musicgen_melody.py | 29 ++------- .../models/nllb_moe/modeling_nllb_moe.py | 56 ++++------------- .../patchtsmixer/modeling_patchtsmixer.py | 12 +--- .../models/patchtst/modeling_patchtst.py | 12 +--- .../models/pegasus/modeling_pegasus.py | 61 ++++--------------- .../models/pegasus_x/modeling_pegasus_x.py | 56 ++++------------- .../models/plbart/modeling_plbart.py | 59 ++++-------------- .../models/plbart/modular_plbart.py | 36 +++-------- src/transformers/models/sew/modeling_sew.py | 10 +-- .../speech_to_text/modeling_speech_to_text.py | 56 ++++------------- .../modeling_time_series_transformer.py | 56 ++++------------- .../models/unispeech/modeling_unispeech.py | 42 +++---------- .../unispeech_sat/modeling_unispeech_sat.py | 42 +++---------- .../models/wav2vec2/modeling_wav2vec2.py | 44 +++---------- 28 files changed, 239 insertions(+), 927 deletions(-) diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index ed83b9eefe9d..0a41692f69cd 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -916,21 +916,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -1018,14 +1013,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index 20ac613db9d5..a6423578ef1f 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -1175,8 +1175,8 @@ def forward( return outputs -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -1188,11 +1188,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -1315,7 +1310,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -1329,7 +1324,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -1612,14 +1606,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -1633,7 +1621,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -1645,7 +1633,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -1679,7 +1667,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -1752,14 +1739,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1768,11 +1753,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -2272,25 +2253,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 27eddcf716e0..d8019dfa6215 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -92,7 +92,7 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -104,11 +104,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -230,7 +225,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -244,7 +239,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -382,14 +376,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -403,7 +391,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -415,7 +403,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -449,7 +437,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -620,16 +607,11 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, - self.config.attention_probs_dropout_prob, ) # embed positions diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index 08c1a16218aa..ce3de6cabe98 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -203,14 +203,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -224,7 +218,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -236,7 +230,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -270,7 +264,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -441,16 +434,11 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - output_attentions, - self.config.attention_probs_dropout_prob, ) # embed positions diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 970d50f1dc61..e4d8c35ba6ef 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -111,8 +111,8 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -124,11 +124,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -251,7 +246,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -265,7 +260,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -505,21 +499,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -535,14 +524,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -556,7 +539,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -568,7 +551,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -602,7 +585,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -675,14 +657,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -691,11 +671,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -822,14 +798,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1076,25 +1047,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 936542f12ad3..33d9a11557f9 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -95,8 +95,8 @@ def forward( return super().forward(position_ids) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -108,11 +108,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -235,7 +230,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -249,7 +244,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -493,21 +487,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -523,14 +512,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -544,7 +527,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -556,7 +539,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -590,7 +573,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -663,14 +645,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -679,11 +659,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -809,14 +785,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1057,25 +1028,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index 1491234d49c3..affdec3dcc75 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -43,7 +43,7 @@ ) from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging +from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available from .configuration_data2vec_audio import Data2VecAudioConfig @@ -51,9 +51,6 @@ from ...integrations.flex_attention import make_flex_block_causal_mask -logger = logging.get_logger(__name__) - - class Data2VecAudioConvLayer(nn.Module): def __init__(self, config, layer_id=0): super().__init__() @@ -184,7 +181,7 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -196,11 +193,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -318,7 +310,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -332,7 +324,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -428,14 +419,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -489,21 +475,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index e1d174c070ee..115345407e69 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -241,7 +241,7 @@ def forward(self, hidden_states): return hidden_states -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -253,11 +253,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -375,7 +370,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -389,7 +384,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -485,14 +479,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -546,21 +535,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -669,14 +653,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -732,21 +711,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 72562a548cd4..330bc620bc03 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -273,21 +273,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -303,12 +298,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -317,11 +311,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -348,13 +338,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -363,11 +352,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -383,7 +368,7 @@ def _update_cross_attn_mask( return encoder_attention_mask -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -395,11 +380,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -521,7 +501,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -535,7 +515,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -1271,23 +1250,17 @@ def forward( past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, - _unsupported_features, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, ) hidden_states = self.value_embedding(inputs_embeds) diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py index cec23ba220ed..15bcb8d38a83 100644 --- a/src/transformers/models/informer/modular_informer.py +++ b/src/transformers/models/informer/modular_informer.py @@ -114,21 +114,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -144,12 +139,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -158,11 +152,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -189,13 +179,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -204,11 +193,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 407365e00f46..f2cf438041f2 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -176,8 +176,8 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_ return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -189,11 +189,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -316,7 +311,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -330,7 +325,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -560,21 +554,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -590,14 +579,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -611,7 +594,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -623,7 +606,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -657,7 +640,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -730,14 +712,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -746,11 +726,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -879,14 +855,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1125,25 +1096,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 1fc7944170c8..394ca56c9702 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -111,8 +111,8 @@ def forward( return super().forward(position_ids) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -124,11 +124,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -251,7 +246,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -265,7 +260,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -510,21 +504,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -540,14 +529,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -561,7 +544,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -573,7 +556,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -607,7 +590,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -829,14 +811,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1074,25 +1051,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 5dc0c457bb3b..ed3325d4443a 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -121,8 +121,8 @@ def forward(self, input_ids: torch.Tensor): return super().forward(input_ids) * self.embed_scale -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -134,11 +134,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -261,7 +256,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -275,7 +270,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -537,21 +531,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -567,14 +556,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -588,7 +571,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -600,7 +583,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -634,7 +617,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -707,14 +689,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -723,11 +703,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -862,14 +838,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1115,25 +1086,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index d5bef01a34c7..a0e21f586cfc 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -150,8 +150,8 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -163,11 +163,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -286,7 +281,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -300,7 +295,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -575,23 +569,17 @@ def forward( if inputs_embeds is None: inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)]) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, - _unsupported_features, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, ) # embed positions @@ -697,12 +685,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -711,11 +698,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -741,13 +724,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -756,11 +738,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 00f7a48239fe..3312ad33cdb1 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -166,8 +166,8 @@ def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -179,11 +179,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -302,7 +297,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -316,7 +311,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -559,16 +553,11 @@ def forward( input_shape = inputs_embeds.size()[:-1] - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, - _unsupported_features, ) # embed positions @@ -657,12 +646,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -671,11 +659,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -702,7 +686,6 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # MusicgenMelody doesn't apply cross attention, hence it's ignored here # and only exists to not confuse any copy checks diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index d724ca032c67..6d7bd6c985d1 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -471,8 +471,8 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens return hidden_states, (router_probs, top_1_expert_index) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -484,11 +484,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -607,7 +602,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -621,7 +616,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -999,14 +993,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1076,21 +1065,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -1252,23 +1236,17 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, - _unsupported_features, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, ) # embed positions @@ -1400,12 +1378,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -1414,11 +1391,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -1445,13 +1418,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1460,11 +1432,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 11beeee5b39d..8f00e8900928 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -237,8 +237,8 @@ def forward(self, inputs: torch.Tensor): return out -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -250,11 +250,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -373,7 +368,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -387,7 +382,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index 0f40f1cded14..b85e8a66b254 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -34,8 +34,8 @@ logger = logging.get_logger(__name__) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -47,11 +47,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -170,7 +165,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -184,7 +179,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 6d2f77513985..e65b8f55c7de 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -110,8 +110,8 @@ def forward( return super().forward(position_ids) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -123,11 +123,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -250,7 +245,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -264,7 +259,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -494,21 +488,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -524,14 +513,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -545,7 +528,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -557,7 +540,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -591,7 +574,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -664,14 +646,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -680,11 +660,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -841,14 +817,9 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1123,25 +1094,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 9a230c9e6911..7ff113992cfd 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -136,8 +136,8 @@ def forward( return pe[None].expand(batch_size, -1, -1) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -149,11 +149,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -276,7 +271,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -290,7 +285,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -791,21 +785,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -821,14 +810,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -842,7 +825,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -854,7 +837,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -888,7 +871,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -961,14 +943,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -977,11 +957,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -1377,25 +1353,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True causal_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 96dd97a199ab..cbacb54332c7 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -96,21 +96,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -126,14 +121,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -147,7 +136,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -159,7 +148,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -193,7 +182,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -266,14 +254,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -282,11 +268,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -327,7 +309,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, posi return super().forward(position_ids + self.offset) -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -339,11 +321,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -465,7 +442,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -479,7 +456,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -678,14 +654,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1047,25 +1018,17 @@ def forward( else past_key_values ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, inputs_embeds, cache_position, self_attn_cache, - _unsupported_features, - self.config.attention_dropout, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, - self.config.attention_dropout, ) # embed positions diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 716374a2dc89..095043544d95 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -82,21 +82,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -112,14 +107,8 @@ def _update_causal_mask( input_tensor: torch.Tensor, cache_position: torch.Tensor, past_key_values: Cache, - _unsupported_features: bool = False, - dropout: float = 0.0, ): - if ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + if self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -133,7 +122,7 @@ def _update_causal_mask( ) return attention_mask - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and (attention_mask == 0.0).any(): return attention_mask return None @@ -145,7 +134,7 @@ def _update_causal_mask( using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward - if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not _unsupported_features: + if self.config._attn_implementation == "sdpa" and not using_compilable_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, @@ -179,7 +168,6 @@ def _update_causal_mask( self.config._attn_implementation == "sdpa" and attention_mask is not None and attention_mask.device.type in ["cuda", "xpu", "npu"] - and not _unsupported_features ): # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. @@ -252,14 +240,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -268,11 +254,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index a41dd35807d0..330cd99a7b4e 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -234,7 +234,7 @@ def forward(self, input_values): return hidden_states -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -246,11 +246,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -368,7 +363,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -382,7 +377,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index 2680d74978ef..aa4ea8107110 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -173,8 +173,8 @@ def create_position_ids_from_input_ids( return incremental_indices.long() + padding_idx -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -186,11 +186,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -309,7 +304,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -323,7 +318,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -672,14 +666,9 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -740,21 +729,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -907,23 +891,17 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, - _unsupported_features, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, ) # embed positions @@ -1028,12 +1006,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -1042,11 +1019,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -1073,13 +1046,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1088,11 +1060,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 077b0a7c2f57..dc960efbbcf3 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -276,8 +276,8 @@ def forward(self, x): return self.value_projection(x) -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -289,11 +289,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -416,7 +411,7 @@ def forward( if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -430,7 +425,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -662,21 +656,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -692,12 +681,11 @@ def _update_causal_mask( input_shape: torch.Size, inputs_embeds: torch.Tensor, past_key_values_length: int, - _unsupported_features: bool, ): - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": # 2d mask is passed through the layers attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( @@ -706,11 +694,7 @@ def _update_causal_mask( inputs_embeds, past_key_values_length, ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask) # Other attention flavors support in-built causal (when `mask is None`) @@ -737,13 +721,12 @@ def _update_cross_attn_mask( encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -752,11 +735,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, @@ -849,14 +828,9 @@ def forward( hidden_states = self.layernorm_embedding(hidden_states + embed_pos) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or head_mask is not None attention_mask = self._update_full_mask( attention_mask, inputs_embeds, - _unsupported_features, ) encoder_states = () if output_hidden_states else None @@ -1043,23 +1017,17 @@ def forward( past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device ) - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True or cross_attn_head_mask is not None or head_mask is not None attention_mask = self._update_causal_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length, - _unsupported_features, ) encoder_attention_mask = self._update_cross_attn_mask( encoder_hidden_states, encoder_attention_mask, input_shape, inputs_embeds, - _unsupported_features, ) hidden_states = self.value_embedding(inputs_embeds) diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 5811e88a56f9..4fdce328e9e6 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -280,7 +280,7 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -292,11 +292,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -414,7 +409,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -428,7 +423,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -524,14 +518,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -585,21 +574,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -708,14 +692,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -771,21 +750,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 2102cae53cd1..50ee4c198d25 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -283,7 +283,7 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states -def eager_attn_forward( +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -295,11 +295,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -417,7 +412,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -431,7 +426,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -527,14 +521,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -588,21 +577,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -711,14 +695,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -774,21 +753,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 9de46dff4693..ae3510f175e0 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -471,8 +471,8 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states -# Copied from transformers.models.bart.modeling_bart.eager_attn_forward -def eager_attn_forward( +# Copied from transformers.models.bart.modeling_bart.eager_attention_forward +def eager_attention_forward( module: nn.Module, query: torch.Tensor, key: torch.Tensor, @@ -484,11 +484,6 @@ def eager_attn_forward( **kwargs, ): if scaling is None: - logger.warning_once( - "You are using a model's `eager` attention module but are not passing its appropriate attention scaling." - " We default to `head_dim**-0.5`. If this is unexpected, please report this to the Transformers GitHub" - " repo: https://github.com/huggingface/transformers/issues/new" - ) scaling = query.size(-1) ** -0.5 attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling @@ -606,7 +601,7 @@ def forward( # if encoder bi-directional self-attention `past_key_value` is always `None` past_key_value = (key_states, value_states) - attention_interface: Callable = eager_attn_forward + attention_interface: Callable = eager_attention_forward if self.config._attn_implementation != "eager": attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] @@ -620,7 +615,6 @@ def forward( scaling=self.scaling, output_attentions=output_attentions, head_mask=layer_head_mask, - eager_fallback=eager_attn_forward, **kwargs, ) @@ -762,14 +756,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -824,21 +813,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: @@ -876,14 +860,9 @@ def forward( expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2]) hidden_states[~expand_attention_mask] = 0 - # Efficient attention implementations are not able to interact with certain features, - # e.g. outputting the attention weights, applying a head mask, and dropout (flex attention). - # In these cases, we fall back to the eager attention to enable the requested feature(s). - _unsupported_features = output_attentions is True attention_mask = self._update_full_mask( attention_mask, hidden_states, - _unsupported_features, ) position_embeddings = self.pos_conv_embed(hidden_states) @@ -940,21 +919,16 @@ def _update_full_mask( self, attention_mask: Union[torch.Tensor, None], inputs_embeds: torch.Tensor, - _unsupported_features: bool, ): if attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": attention_mask = attention_mask if 0 in attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & head_mask can not be supported when using SDPA, fall back to # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (self.config.attention_dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(attention_mask, torch.Tensor): attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False) else: From cd39964eacbf0145e078148f2007bfeb5b06a341 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 22 May 2025 12:58:56 +0200 Subject: [PATCH 66/68] wrong copy --- src/transformers/models/marian/modeling_marian.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index 394ca56c9702..fc9cb3a7ad59 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -655,21 +655,19 @@ def _prepare_4d_causal_attention_mask_with_cache_position( return causal_mask - # Copied from trasformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask + # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask def _update_cross_attn_mask( self, encoder_hidden_states: Union[torch.Tensor, None], encoder_attention_mask: Union[torch.Tensor, None], input_shape: torch.Size, inputs_embeds: torch.Tensor, - _unsupported_features: bool, - dropout: float = 0.0, ): # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: - if self.config._attn_implementation == "flash_attention_2" and not _unsupported_features: + if self.config._attn_implementation == "flash_attention_2": encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None - elif self.config._attn_implementation == "sdpa" and not _unsupported_features: + elif self.config._attn_implementation == "sdpa": # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on # the manual implementation that requires a 4D causal mask in all cases. # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -678,11 +676,7 @@ def _update_cross_attn_mask( inputs_embeds.dtype, tgt_len=input_shape[-1], ) - elif ( - self.config._attn_implementation == "flex_attention" - and not _unsupported_features - and (dropout == 0 or not self.training) - ): + elif self.config._attn_implementation == "flex_attention": if isinstance(encoder_attention_mask, torch.Tensor): encoder_attention_mask = make_flex_block_causal_mask( encoder_attention_mask, From c450a3d28be49b6e742cfb9c3cb950ac4d144262 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 22 May 2025 13:53:44 +0200 Subject: [PATCH 67/68] fixup tests and disable flex attn for now --- src/transformers/models/bart/modeling_bart.py | 2 +- src/transformers/models/biogpt/modeling_biogpt.py | 2 +- src/transformers/models/biogpt/modular_biogpt.py | 2 +- src/transformers/models/blenderbot/modeling_blenderbot.py | 2 +- .../models/blenderbot_small/modeling_blenderbot_small.py | 2 +- src/transformers/models/data2vec/modeling_data2vec_audio.py | 2 +- src/transformers/models/data2vec/modular_data2vec_audio.py | 2 +- src/transformers/models/hubert/modeling_hubert.py | 3 ++- src/transformers/models/hubert/modular_hubert.py | 3 ++- src/transformers/models/m2m_100/modeling_m2m_100.py | 3 ++- src/transformers/models/marian/modeling_marian.py | 3 ++- src/transformers/models/mbart/modeling_mbart.py | 3 ++- src/transformers/models/pegasus/modeling_pegasus.py | 3 ++- src/transformers/models/pegasus_x/modeling_pegasus_x.py | 5 +++-- src/transformers/models/plbart/modeling_plbart.py | 3 ++- src/transformers/models/plbart/modular_plbart.py | 3 ++- .../models/speech_to_text/modeling_speech_to_text.py | 1 + .../modeling_time_series_transformer.py | 1 + src/transformers/models/unispeech/modeling_unispeech.py | 3 ++- src/transformers/models/unispeech/modular_unispeech.py | 3 ++- .../models/unispeech_sat/modeling_unispeech_sat.py | 3 ++- .../models/unispeech_sat/modular_unispeech_sat.py | 3 ++- src/transformers/models/wav2vec2/modeling_wav2vec2.py | 3 ++- tests/models/data2vec/test_modeling_data2vec_audio.py | 3 +++ tests/models/hubert/test_modeling_hubert.py | 6 ++++++ tests/models/musicgen/test_modeling_musicgen.py | 6 ++++++ .../models/musicgen_melody/test_modeling_musicgen_melody.py | 6 ++++++ tests/models/sew/test_modeling_sew.py | 3 +++ tests/models/unispeech/test_modeling_unispeech.py | 3 +++ tests/models/unispeech_sat/test_modeling_unispeech_sat.py | 6 ++++++ tests/models/wav2vec2/test_modeling_wav2vec2.py | 6 ++++++ 31 files changed, 77 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index 825e96677de6..9bbae4ecde78 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -494,7 +494,7 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True - # TODO: compilation issues + # Compile issues _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index d8019dfa6215..8fbfe244e76f 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -348,7 +348,7 @@ class BioGptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # compiling issues + # Compile issues _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py index ce3de6cabe98..c8fe5855e294 100644 --- a/src/transformers/models/biogpt/modular_biogpt.py +++ b/src/transformers/models/biogpt/modular_biogpt.py @@ -175,7 +175,7 @@ class BioGptPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # compiling issues + # Compile issues _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index e4d8c35ba6ef..5ac8049bd1bf 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -464,7 +464,7 @@ class BlenderbotPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # TODO: compilation issues + # Compile issues _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index 33d9a11557f9..986435ad64d1 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -452,7 +452,7 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # TODO: compilation issues + # Compile issues _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py index affdec3dcc75..d9046ea6e8c9 100755 --- a/src/transformers/models/data2vec/modeling_data2vec_audio.py +++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py @@ -551,7 +551,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # TODO: compilation issues + # Compile issues _supports_flex_attn = False def _init_weights(self, module): diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py index 51dcc85c63bf..73a42937bd8e 100644 --- a/src/transformers/models/data2vec/modular_data2vec_audio.py +++ b/src/transformers/models/data2vec/modular_data2vec_audio.py @@ -140,7 +140,7 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - # TODO: compilation issues + # Compile issues _supports_flex_attn = False def _init_weights(self, module): diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index 115345407e69..eb366963a674 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -738,7 +738,8 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py index c0454452f029..75000c95cb38 100644 --- a/src/transformers/models/hubert/modular_hubert.py +++ b/src/transformers/models/hubert/modular_hubert.py @@ -131,7 +131,8 @@ class HubertPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index f2cf438041f2..e5788ef60b74 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -530,7 +530,8 @@ class M2M100PreTrainedModel(PreTrainedModel): _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True # Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model _supports_static_cache = False diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index fc9cb3a7ad59..fa6cf4a9bc85 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -468,7 +468,8 @@ class MarianPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index ed3325d4443a..2cc6a048efb2 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -498,7 +498,8 @@ class MBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index e65b8f55c7de..c7b2d9a8661a 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -463,7 +463,8 @@ class PegasusPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 7ff113992cfd..9a5a616cf331 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -762,9 +762,10 @@ class PegasusXPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"] _supports_flash_attn_2 = True - # TODO: Flaky logits + # Flaky logits _supports_sdpa = False - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False _supports_cache_class = True _supports_static_cache = True diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index cbacb54332c7..a5b8b7ccfa9d 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -78,7 +78,8 @@ class PLBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py index 095043544d95..e6ab8a2c8c10 100644 --- a/src/transformers/models/plbart/modular_plbart.py +++ b/src/transformers/models/plbart/modular_plbart.py @@ -64,7 +64,8 @@ class PLBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index aa4ea8107110..4acee66424bb 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -529,6 +529,7 @@ class Speech2TextPreTrainedModel(PreTrainedModel): # Current tests always assume certain inputs to be passed _supports_flash_attn_2 = False _supports_sdpa = False + # Compile issues _supports_flex_attn = False def _init_weights(self, module): diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index dc960efbbcf3..3bc19a75b3c3 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -636,6 +636,7 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel): # Current tests always assume certain inputs to be passed _supports_flash_attn_2 = False _supports_sdpa = False + # Compile issues _supports_flex_attn = False def _init_weights(self, module): diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index 4fdce328e9e6..07ee6608b7ee 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -847,7 +847,8 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py index 5a9133089aeb..795ab8596730 100644 --- a/src/transformers/models/unispeech/modular_unispeech.py +++ b/src/transformers/models/unispeech/modular_unispeech.py @@ -151,7 +151,8 @@ class UniSpeechPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 50ee4c198d25..8d9ac9c33fcc 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -850,7 +850,8 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py index f86c397a047c..9f9e7d4f3c52 100644 --- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py @@ -161,7 +161,8 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index ae3510f175e0..fb01234e3fed 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1096,7 +1096,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): supports_gradient_checkpointing = True _supports_flash_attn_2 = True _supports_sdpa = True - _supports_flex_attn = True + # Compile issues + _supports_flex_attn = False def _init_weights(self, module): """Initialize the weights""" diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py index 96e970beb6b9..5a8f410a70d7 100644 --- a/tests/models/data2vec/test_modeling_data2vec_audio.py +++ b/tests/models/data2vec/test_modeling_data2vec_audio.py @@ -420,6 +420,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py index c4e65cfa5c6b..de26f4c7a4e3 100644 --- a/tests/models/hubert/test_modeling_hubert.py +++ b/tests/models/hubert/test_modeling_hubert.py @@ -370,6 +370,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -632,6 +635,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py index 0eda471e09fc..1a27192506f3 100644 --- a/tests/models/musicgen/test_modeling_musicgen.py +++ b/tests/models/musicgen/test_modeling_musicgen.py @@ -728,6 +728,9 @@ def check_musicgen_model_output_attentions_from_config( def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # force eager attention to support output attentions + config._attn_implementation = "eager" + for model_class in self.all_model_classes: self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict) self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict) @@ -805,6 +808,9 @@ def test_retain_grad_hidden_states_attentions(self): config.text_encoder.output_attentions = True config.decoder.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py index a5971fabd5b6..abf2edd1ce3e 100644 --- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py +++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py @@ -731,6 +731,9 @@ def check_musicgen_melody_model_output_attentions_from_config( def test_attention_outputs(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + # force eager attention to support output attentions + config._attn_implementation = "eager" + for model_class in self.all_model_classes: self.check_musicgen_melody_model_output_attentions(model_class, config, **inputs_dict) self.check_musicgen_melody_model_output_attentions_from_config(model_class, config, **inputs_dict) @@ -807,6 +810,9 @@ def test_retain_grad_hidden_states_attentions(self): config.text_encoder.output_attentions = True config.decoder.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py index 9a7430c64112..2cab21cf5c92 100644 --- a/tests/models/sew/test_modeling_sew.py +++ b/tests/models/sew/test_modeling_sew.py @@ -342,6 +342,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/unispeech/test_modeling_unispeech.py b/tests/models/unispeech/test_modeling_unispeech.py index 0f8e2448320b..ebc537a47886 100644 --- a/tests/models/unispeech/test_modeling_unispeech.py +++ b/tests/models/unispeech/test_modeling_unispeech.py @@ -383,6 +383,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py index dd8deafc227f..ec438dea96b4 100644 --- a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py +++ b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py @@ -423,6 +423,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -632,6 +635,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py index c47d4855ccd9..9597d2e6ef25 100644 --- a/tests/models/wav2vec2/test_modeling_wav2vec2.py +++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py @@ -570,6 +570,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) @@ -917,6 +920,9 @@ def test_retain_grad_hidden_states_attentions(self): config.output_hidden_states = True config.output_attentions = True + # force eager attention to support output attentions + config._attn_implementation = "eager" + # no need to test all models as different heads yield the same functionality model_class = self.all_model_classes[0] model = model_class(config) From 22114cd8c806884575934bd24aa8f7b82f7baa11 Mon Sep 17 00:00:00 2001 From: Vasqu Date: Thu, 22 May 2025 14:43:28 +0200 Subject: [PATCH 68/68] fixup last tests? --- .../test_modeling_encoder_decoder.py | 8 +++++++ .../test_modeling_speech_encoder_decoder.py | 4 ++++ .../test_modeling_vision_encoder_decoder.py | 24 +++++++++++++++++++ 3 files changed, 36 insertions(+) diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py index 6df93374137b..9aceea0359b0 100644 --- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py +++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py @@ -412,6 +412,10 @@ def check_encoder_decoder_model_output_attentions( labels, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -445,6 +449,10 @@ def check_encoder_decoder_model_output_attentions_from_config( # config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded # from the inner models' configurations. + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config) diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py index 7ac8d5631d20..28cdaf34473e 100644 --- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py +++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py @@ -300,6 +300,10 @@ def check_encoder_decoder_model_output_attentions( input_features=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py index bfd9aad2332a..ffd08297f147 100644 --- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py +++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py @@ -246,6 +246,10 @@ def check_encoder_decoder_model_output_attentions( pixel_values=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -480,6 +484,10 @@ def check_encoder_decoder_model_output_attentions( pixel_values=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -670,6 +678,10 @@ def check_encoder_decoder_model_output_attentions( pixel_values=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -807,6 +819,10 @@ def check_encoder_decoder_model_output_attentions( labels=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -929,6 +945,10 @@ def check_encoder_decoder_model_output_attentions( labels=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1] @@ -1047,6 +1067,10 @@ def check_encoder_decoder_model_output_attentions( labels=None, **kwargs, ): + # force eager attention to support output attentions + config._attn_implementation = "eager" + decoder_config._attn_implementation = "eager" + # make the decoder inputs a different shape from the encoder inputs to harden the test decoder_input_ids = decoder_input_ids[:, :-1] decoder_attention_mask = decoder_attention_mask[:, :-1]