diff --git a/docs/source/en/model_doc/biogpt.md b/docs/source/en/model_doc/biogpt.md
index d7145993a89c..11ab89c9f2d9 100644
--- a/docs/source/en/model_doc/biogpt.md
+++ b/docs/source/en/model_doc/biogpt.md
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.
@@ -40,13 +41,13 @@ This model was contributed by [kamalkraj](https://huggingface.co/kamalkraj). The
### Using Scaled Dot Product Attention (SDPA)
-PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
-encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
-[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
+PyTorch includes a native scaled dot-product attention (SDPA) operator as part of `torch.nn.functional`. This function
+encompasses several implementations that can be applied depending on the inputs and the hardware in use. See the
+[official documentation](https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html)
or the [GPU Inference](https://huggingface.co/docs/transformers/main/en/perf_infer_gpu_one#pytorch-scaled-dot-product-attention)
page for more information.
-SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
+SDPA is used by default for `torch>=2.1.1` when an implementation is available, but you may also set
`attn_implementation="sdpa"` in `from_pretrained()` to explicitly request SDPA to be used.
```
@@ -109,7 +110,7 @@ we saw the following speedups during inference.
[[autodoc]] BioGptForCausalLM
- forward
-
+
## BioGptForTokenClassification
[[autodoc]] BioGptForTokenClassification
diff --git a/docs/source/en/model_doc/blenderbot-small.md b/docs/source/en/model_doc/blenderbot-small.md
index 647a865de339..341e43c03040 100644
--- a/docs/source/en/model_doc/blenderbot-small.md
+++ b/docs/source/en/model_doc/blenderbot-small.md
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
+
+
Note that [`BlenderbotSmallModel`] and
@@ -52,7 +54,7 @@ found [here](https://github.com/facebookresearch/ParlAI).
## Usage tips
-Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
+Blenderbot Small is a model with absolute position embeddings so it's usually advised to pad the inputs on the right rather than
the left.
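For illustration, a minimal sketch of padding on the right with the tokenizer (checkpoint name used here only as an example):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/blenderbot_small-90M")
tokenizer.padding_side = "right"  # absolute position embeddings: pad on the right, not the left
batch = tokenizer(
    ["sam is a great name", "what should we name the dog?"],
    padding=True,
    return_tensors="pt",
)
print(batch["attention_mask"])  # zeros land at the end of the shorter sequence
```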
diff --git a/docs/source/en/model_doc/blenderbot.md b/docs/source/en/model_doc/blenderbot.md
index ec24d5ed7495..adfa6841e10a 100644
--- a/docs/source/en/model_doc/blenderbot.md
+++ b/docs/source/en/model_doc/blenderbot.md
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
+
+
## Overview
@@ -45,7 +47,7 @@ This model was contributed by [sshleifer](https://huggingface.co/sshleifer). The
## Usage tips and example
-Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
+Blenderbot is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
rather than the left.
An example:
@@ -71,7 +73,7 @@ An example:
`facebook/blenderbot_small_90M`, have a different architecture and consequently should be used with
[BlenderbotSmall](blenderbot-small).
-
+
## Resources
- [Causal language modeling task guide](../tasks/language_modeling)
diff --git a/docs/source/en/model_doc/marian.md b/docs/source/en/model_doc/marian.md
index 80bb73d26df1..4fcd6363559c 100644
--- a/docs/source/en/model_doc/marian.md
+++ b/docs/source/en/model_doc/marian.md
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
+
+
## Overview
@@ -155,7 +157,7 @@ Example of translating english to many romance languages, using old-style 2 char
>>> model = MarianMTModel.from_pretrained(model_name)
>>> translated = model.generate(**tokenizer(src_text, return_tensors="pt", padding=True))
>>> tgt_text = [tokenizer.decode(t, skip_special_tokens=True) for t in translated]
-["c'est une phrase en anglais que nous voulons traduire en français",
+["c'est une phrase en anglais que nous voulons traduire en français",
'Isto deve ir para o português.',
'Y esto al español']
```
diff --git a/docs/source/en/model_doc/nllb-moe.md b/docs/source/en/model_doc/nllb-moe.md
index 65a4812ed6ab..fc8c8c92115d 100644
--- a/docs/source/en/model_doc/nllb-moe.md
+++ b/docs/source/en/model_doc/nllb-moe.md
@@ -51,10 +51,10 @@ The original code can be found [here](https://github.com/facebookresearch/fairse
## Implementation differences with SwitchTransformers
-The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
-highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
-which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
-states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
+The biggest difference is the way the tokens are routed. NLLB-MoE uses a `top-2-gate` which means that for each input, only the top two experts are selected based on the
+highest predicted probabilities from the gating network, and the remaining experts are ignored. In `SwitchTransformers`, only the top-1 probabilities are computed,
+which means that tokens have less probability of being forwarded. Moreover, if a token is not routed to any expert, `SwitchTransformers` still adds its unmodified hidden
+states (kind of like a residual connection) while they are masked in `NLLB`'s top-2 routing mechanism.
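As a toy sketch of the routing difference described above (made-up tensors, not the library's router implementation):

```python
import torch

router_logits = torch.randn(6, 4)  # (tokens, experts), toy values
probs = router_logits.softmax(dim=-1)

# NLLB-MoE style top-2 gate: each token is dispatched to its two most probable experts
top2_probs, top2_experts = probs.topk(2, dim=-1)

# SwitchTransformers style top-1 gate: only the single best expert is kept per token
top1_probs, top1_experts = probs.max(dim=-1)

print(top2_experts.shape, top1_experts.shape)  # torch.Size([6, 2]) torch.Size([6])
```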
## Generating with NLLB-MoE
diff --git a/docs/source/en/model_doc/pegasus.md b/docs/source/en/model_doc/pegasus.md
index bdb61e66d984..5681ac9b58a0 100644
--- a/docs/source/en/model_doc/pegasus.md
+++ b/docs/source/en/model_doc/pegasus.md
@@ -21,6 +21,8 @@ rendered properly in your Markdown viewer.
+
+
## Overview
diff --git a/docs/source/en/model_doc/pegasus_x.md b/docs/source/en/model_doc/pegasus_x.md
index 3f982263cdb1..97e50601b725 100644
--- a/docs/source/en/model_doc/pegasus_x.md
+++ b/docs/source/en/model_doc/pegasus_x.md
@@ -18,6 +18,7 @@ rendered properly in your Markdown viewer.

+
## Overview
diff --git a/docs/source/en/model_doc/plbart.md b/docs/source/en/model_doc/plbart.md
index bac567615d42..d57ee8ed99e8 100644
--- a/docs/source/en/model_doc/plbart.md
+++ b/docs/source/en/model_doc/plbart.md
@@ -18,6 +18,8 @@ rendered properly in your Markdown viewer.
## Overview
@@ -29,7 +31,7 @@ on Java, Python and English.
According to the abstract
*Code summarization and generation empower conversion between programming language (PL) and natural language (NL),
-while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART,
+while code translation avails the migration of legacy code from one PL to another. This paper introduces PLBART,
a sequence-to-sequence model capable of performing a broad spectrum of program and language understanding and generation tasks.
PLBART is pre-trained on an extensive collection of Java and Python functions and associated NL text via denoising autoencoding.
Experiments on code summarization in the English language, code generation, and code translation in seven programming languages
@@ -50,7 +52,7 @@ target text format is `[tgt_lang_code] X [eos]`. `bos` is never used.
However, for fine-tuning, in some cases no language token is provided in cases where a single language is used. Please refer to [the paper](https://arxiv.org/abs/2103.06333) to learn more about this.
-In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format
+In cases where the language code is needed, the regular [`~PLBartTokenizer.__call__`] will encode source text format
when you pass texts as the first argument or with the keyword argument `text`, and will encode target text format if
it's passed with the `text_target` keyword argument.
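A short example of the two encoding paths described above, following the checkpoint and language codes used in the PLBart docs:

```python
from transformers import PLBartTokenizer

tokenizer = PLBartTokenizer.from_pretrained(
    "uclanlp/plbart-python-en_XX", src_lang="python", tgt_lang="en_XX"
)
code = "def maximum(a, b): return max(a, b)"
# first positional argument / `text` -> source format, `text_target` -> target format
inputs = tokenizer(code, text_target="Returns the maximum of two values.", return_tensors="pt")
print(inputs["input_ids"].shape, inputs["labels"].shape)
```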
diff --git a/src/transformers/integrations/flash_attention.py b/src/transformers/integrations/flash_attention.py
index a78166ed040b..4f76e65a847b 100644
--- a/src/transformers/integrations/flash_attention.py
+++ b/src/transformers/integrations/flash_attention.py
@@ -3,8 +3,11 @@
import torch
from ..modeling_flash_attention_utils import _flash_attention_forward, flash_attn_supports_top_left_mask
+from ..utils import logging
+logger = logging.get_logger(__name__)
+
_use_top_left_mask = flash_attn_supports_top_left_mask()
@@ -20,6 +23,12 @@ def flash_attention_forward(
softcap: Optional[float] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
+ if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
+ logger.warning_once(
+ "`flash_attention_2` does not support `output_attentions=True` or `head_mask`."
+ " Please set your attention to `eager` if you want any of these features."
+ )
+
# This is before the transpose
seq_len = query.shape[2]
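The new warning relies on `logger.warning_once` from `transformers.utils.logging`, which emits a given message only the first time it is hit; a tiny sketch:

```python
from transformers.utils import logging

logger = logging.get_logger("example")
logger.warning_once("flash attention fallback notice")
logger.warning_once("flash attention fallback notice")  # deduplicated on the second call
```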
diff --git a/src/transformers/integrations/flex_attention.py b/src/transformers/integrations/flex_attention.py
index afdaba5199de..cc9787657bc7 100644
--- a/src/transformers/integrations/flex_attention.py
+++ b/src/transformers/integrations/flex_attention.py
@@ -31,13 +31,15 @@
import torch
from packaging import version
-from ..utils import is_torch_flex_attn_available
+from ..utils import is_torch_flex_attn_available, logging
from ..utils.import_utils import _torch_version, is_torchdynamo_compiling
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask, flex_attention
- from torch.nn.attention.flex_attention import create_block_mask as create_block_causal_mask_flex
+ from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
+
+
+logger = logging.get_logger(__name__)
class WrappedFlexAttention:
@@ -98,21 +100,23 @@ def compile_friendly_flex_attention(
Offset = Union[torch.Tensor, int]
+# TODO: deprecate / rename to make_flex_block_mask for clarity as it's not only causal anymore
def make_flex_block_causal_mask(
attention_mask_2d: torch.Tensor,
attention_chunk_size: Optional[int] = None,
query_length=None,
key_length=None,
offsets: Optional[Tuple[Offset, Offset]] = None,
+ is_causal: Optional[bool] = True,
) -> "BlockMask":
"""
IMPORTANT NOTICE: This function is deprecated in favor of using the mask primitives in `masking_utils.py`,
and will be removed in a future version without warnings. New code should not use it. It is only kept here
for BC for now, while models using it are being patched accordingly.
- Create a block causal document mask for a batch of sequences, both packed and unpacked.
- Create Block causal logic and passing it into :func:`torch.nn.attention.flex_attention.create_block_mask`.
- The resultant BlockMask is a compressed representation of the full block causal
+    Create a block (causal) document mask for a batch of sequences, both packed and unpacked.
+    The block (causal) mask logic is built here and passed to :func:`torch.nn.attention.flex_attention.create_block_mask`.
+    The resultant BlockMask is a compressed representation of the full (causal) block
mask. BlockMask is essential for performant computation of flex attention.
See: https://pytorch.org/blog/flexattention/
@@ -170,7 +174,21 @@ def chunk_causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
causal_doc_mask = causal_mask_mod(batch_idx, head_idx, q_idx, kv_idx)
return chunk_mask & causal_doc_mask
- mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
+ def default_mask_mod(batch_idx, head_idx, q_idx, kv_idx):
+ """
+ Utilizes default attention mask to enable encoder and encoder-decoder
+ attention masks.
+ """
+ document_mask = document_ids[batch_idx, q_idx] == document_ids[batch_idx, kv_idx]
+ # kv indexing is crucial in order to work correctly
+ padding_mask = attention_mask_2d[batch_idx, kv_idx] > 0
+ final_mask = padding_mask & document_mask
+ return final_mask
+
+ if not is_causal:
+ mask_mod_maybe_combined = default_mask_mod
+ else:
+ mask_mod_maybe_combined = causal_mask_mod if attention_chunk_size is None else chunk_causal_mask_mod
if offsets is not None:
q_offset = offsets[0]
@@ -182,7 +200,8 @@ def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
return mask_mod_maybe_combined(batch_idx, head_idx, offset_q, offset_kv)
else:
mask_mod = mask_mod_maybe_combined
- return create_block_causal_mask_flex(
+
+ return create_block_mask(
mask_mod=mask_mod,
B=batch_size,
H=None, # attention head
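A minimal sketch of the new `is_causal` flag, mirroring how the encoder-side mask helpers added later in this diff call it (assumes a torch build with flex attention available):

```python
import torch
from transformers.integrations.flex_attention import make_flex_block_causal_mask

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # right-padded batch

# decoder-style (default): causal mask combined with padding/document boundaries
causal_block_mask = make_flex_block_causal_mask(attention_mask)

# encoder-style: bidirectional, only padding/document boundaries are enforced
full_block_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
```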
@@ -216,21 +235,33 @@ def flex_attention_forward(
head_mask: Optional[torch.Tensor] = None,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
+ if kwargs.get("output_attentions", False) or head_mask is not None:
+ logger.warning_once(
+ "`flex_attention` does not support `output_attentions=True` or `head_mask`."
+ " Please set your attention to `eager` if you want any of these features."
+ )
+
+ if kwargs.get("dropout", 0.0) > 0:
+ raise ValueError(
+ "`flex_attention` does not support `dropout`. Please use it with inference"
+ " only (`model.eval()`) or turn off the attention dropout in the respective config."
+ )
+
block_mask = None
- causal_mask = None
+ score_mask = None
if isinstance(attention_mask, BlockMask):
block_mask = attention_mask
else:
- causal_mask = attention_mask
+ score_mask = attention_mask
- if causal_mask is not None:
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+ if score_mask is not None:
+ score_mask = score_mask[:, :, :, : key.shape[-2]]
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
if softcap is not None:
score = softcap * torch.tanh(score / softcap)
- if causal_mask is not None:
- score = score + causal_mask[batch_idx][0][q_idx][kv_idx]
+ if score_mask is not None:
+ score = score + score_mask[batch_idx][0][q_idx][kv_idx]
if head_mask is not None:
score = score + head_mask[batch_idx][head_idx][0][0]
return score
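For context, a standalone toy example of the `score_mod` pattern this hunk renames around (an additive float mask added to the raw scores); shapes are made up and this is not the integration code itself:

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

q = k = v = torch.randn(1, 2, 8, 16)  # (batch, heads, seq, head_dim)
score_mask = torch.zeros(1, 1, 8, 8)  # additive bias, e.g. large negatives at masked positions

def score_mod(score, batch_idx, head_idx, q_idx, kv_idx):
    # same indexing pattern as above: batch-wise bias shared across heads
    return score + score_mask[batch_idx][0][q_idx][kv_idx]

out = flex_attention(q, k, v, score_mod=score_mod)
print(out.shape)  # torch.Size([1, 2, 8, 16])
```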
diff --git a/src/transformers/integrations/sdpa_attention.py b/src/transformers/integrations/sdpa_attention.py
index 9c924c048ad5..247cd2821679 100644
--- a/src/transformers/integrations/sdpa_attention.py
+++ b/src/transformers/integrations/sdpa_attention.py
@@ -2,6 +2,11 @@
import torch
+from ..utils import logging
+
+
+logger = logging.get_logger(__name__)
+
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
@@ -26,13 +31,18 @@ def sdpa_attention_forward(
is_causal: Optional[bool] = None,
**kwargs,
) -> Tuple[torch.Tensor, None]:
+ if kwargs.get("output_attentions", False) or kwargs.get("head_mask", None) is not None:
+ logger.warning_once(
+ "`sdpa` attention does not support `output_attentions=True` or `head_mask`."
+ " Please set your attention to `eager` if you want any of these features."
+ )
+
if hasattr(module, "num_key_value_groups"):
key = repeat_kv(key, module.num_key_value_groups)
value = repeat_kv(value, module.num_key_value_groups)
- causal_mask = attention_mask
- if attention_mask is not None and causal_mask.ndim == 4:
- causal_mask = causal_mask[:, :, :, : key.shape[-2]]
+ if attention_mask is not None and attention_mask.ndim == 4:
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
# SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
# Reference: https://github.com/pytorch/pytorch/issues/112577.
@@ -44,7 +54,9 @@ def sdpa_attention_forward(
# in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
# Note that it is important to check first for the shape, otherwise compile will fail with `argument 'is_causal' must be bool, not SymBool`
if is_causal is None:
- is_causal = query.shape[2] > 1 and causal_mask is None
+ # The last condition is for encoder (decoder) models which specify this by passing their own `is_causal` flag
+ # This is mainly due to those models having mixed implementations for encoder, decoder, and encoder-decoder attns
+ is_causal = query.shape[2] > 1 and attention_mask is None and getattr(module, "is_causal", True)
# Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
# We convert it to a bool for the SDPA kernel that only accepts bools.
@@ -55,7 +67,7 @@ def sdpa_attention_forward(
query,
key,
value,
- attn_mask=causal_mask,
+ attn_mask=attention_mask,
dropout_p=dropout,
scale=scaling,
is_causal=is_causal,
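As a small illustration of the `is_causal` heuristic above (single-query decode steps skip causal masking); a sketch, not the integration itself:

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 2, 5, 16)  # (batch, heads, q_len, head_dim)
k = v = torch.randn(1, 2, 5, 16)

# prefill: q_len > 1 and no explicit mask -> let SDPA build the causal mask itself
prefill = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# single decoding step: q_len == 1, causal masking is a no-op, so is_causal=False
step = F.scaled_dot_product_attention(q[:, :, -1:], k, v, is_causal=False)
print(prefill.shape, step.shape)  # torch.Size([1, 2, 5, 16]) torch.Size([1, 2, 1, 16])
```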
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 97e95b4161b0..0ad3947815b2 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -165,6 +165,7 @@
if is_kernels_available():
from kernels import get_kernel
+
logger = logging.get_logger(__name__)
diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py
index 9f7f1515a278..0a41692f69cd 100644
--- a/src/transformers/models/autoformer/modeling_autoformer.py
+++ b/src/transformers/models/autoformer/modeling_autoformer.py
@@ -26,14 +26,21 @@
from torch import nn
from ...activations import ACT2FN
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
from ...modeling_outputs import BaseModelOutput, ModelOutput, SampleTSPredictionOutput, Seq2SeqTSPredictionOutput
from ...modeling_utils import PreTrainedModel
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
-from ...utils import auto_docstring, logging
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_autoformer import AutoformerConfig
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
logger = logging.get_logger(__name__)
@@ -904,6 +911,29 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerEncoder with TimeSeriesTransformer->Autoformer,TimeSeries->Autoformer
class AutoformerEncoder(AutoformerPreTrainedModel):
@@ -983,10 +1013,10 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index 01f7f19a79e5..60d9cdba2aee 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -17,7 +17,7 @@
import copy
import math
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -32,7 +32,7 @@
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -42,7 +42,8 @@
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -53,13 +54,7 @@
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
-
-
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -119,6 +114,36 @@ def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
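A toy call to the `eager_attention_forward` helper introduced here, just to show the expected shapes (the module argument only needs a `.training` attribute, so a bare `nn.Module` suffices for the sketch):

```python
import torch
from torch import nn

module = nn.Module()          # only .training is consulted (for dropout)
q = torch.randn(2, 4, 6, 8)   # (bsz, num_heads, tgt_len, head_dim)
k = v = torch.randn(2, 4, 6, 8)

attn_output, attn_weights = eager_attention_forward(module, q, k, v, attention_mask=None)
print(attn_output.shape)   # torch.Size([2, 6, 4, 8]) -- transposed back to (bsz, tgt_len, heads, head_dim)
print(attn_weights.shape)  # torch.Size([2, 4, 6, 6])
```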
class BartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -170,151 +195,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attention are all mixed together in this module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
-
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
-
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-class BartFlashAttention2(BartAttention):
- """
- Bart flash attention module. This module inherits from `BartAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # BartFlashAttention2 attention does not support output_attentions
- if output_attentions:
- raise ValueError(
- "BartSdpaAttention2 attention does not support `output_attentions`. "
- "Use the argument `attn_implementation='eager'` when loading the model."
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -335,8 +234,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -348,163 +247,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- attn_output = _flash_attention_forward(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-class BartSdpaAttention(BartAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "BartModel is using BartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- cache_position=cache_position,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
-
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- causal_mask = None
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-BART_ATTENTION_CLASSES = {
- "eager": BartAttention,
- "sdpa": BartSdpaAttention,
- "flash_attention_2": BartFlashAttention2,
-}
+ return attn_output, attn_weights, past_key_value
class BartEncoderLayer(nn.Module):
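With the per-backend attention classes removed in favor of the `ALL_ATTENTION_FUNCTIONS` dispatch above, the backend is chosen purely through `attn_implementation`; a hedged usage sketch (checkpoint name is just an example):

```python
from transformers import BartForConditionalGeneration

# "eager", "sdpa", or "flash_attention_2" all route through the same BartAttention.forward now
model = BartForConditionalGeneration.from_pretrained(
    "facebook/bart-base", attn_implementation="sdpa"
)
```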
@@ -512,7 +275,7 @@ def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BartAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -583,7 +346,7 @@ def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BartAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -597,7 +360,7 @@ def __init__(self, config: BartConfig, layer_idx: Optional[int] = None):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = BART_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = BartAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -671,6 +434,7 @@ def forward(
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@@ -730,6 +494,8 @@ class BartPreTrainedModel(PreTrainedModel):
_skip_keys_device_placement = "past_key_values"
_supports_flash_attn_2 = True
_supports_sdpa = True
+    # Disabled due to compile issues with flex attention
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
@@ -757,24 +523,54 @@ def dummy_inputs(self):
}
return dummy_inputs
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
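For reference, a toy look at what the eager and SDPA branches of `_update_full_mask` produce, using the public mask utilities they call (shapes are made up, not part of the patch):

```python
import torch
from transformers.modeling_attn_mask_utils import (
    _prepare_4d_attention_mask,
    _prepare_4d_attention_mask_for_sdpa,
)

mask_2d = torch.tensor([[1, 1, 1, 0]])  # [bsz, seq_len]
eager_mask = _prepare_4d_attention_mask(mask_2d, torch.float32)
sdpa_mask = _prepare_4d_attention_mask_for_sdpa(mask_2d, torch.float32)

print(eager_mask.shape)  # torch.Size([1, 1, 4, 4]) -> [bsz, 1, tgt_seq_len, src_seq_len]
print(sdpa_mask.shape if sdpa_mask is not None else None)  # same 4D shape here; None when nothing is masked
```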
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -782,7 +578,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -816,7 +612,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -882,6 +677,41 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
class PretrainedBartModel(BartPreTrainedModel):
def __init_subclass__(self):
@@ -932,8 +762,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
embed_dim,
)
self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)])
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
@@ -1019,18 +847,10 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- if self._use_flash_attention_2:
- attention_mask = attention_mask if 0 in attention_mask else None
- elif self._use_sdpa and head_mask is None and not output_attentions:
- # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1116,8 +936,6 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No
config.d_model,
)
self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -1232,12 +1050,18 @@ def forward(
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
# initialize `past_key_values`
return_legacy_cache = False
@@ -1267,38 +1091,25 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
- causal_mask = self._update_causal_mask(
+
+ attention_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
- elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=seq_length,
- )
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
- position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
- position_ids = position_ids.to(inputs_embeds.device)
+ positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
+ positions = positions.to(inputs_embeds.device)
- hidden_states = inputs_embeds + position_ids
+ hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -1331,7 +1142,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- causal_mask,
+ attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -1344,7 +1155,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=causal_mask,
+ attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 4ff34b9ef259..d49d4e65bd70 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -16,7 +16,7 @@
import copy
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -29,7 +29,9 @@
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -39,20 +41,14 @@
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- auto_docstring,
- is_torch_flex_attn_available,
- is_torchdynamo_compiling,
- logging,
-)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from .configuration_bigbird_pegasus import BigBirdPegasusConfig
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -1179,6 +1175,37 @@ def forward(
return outputs
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with BartConfig->BigBirdPegasusConfig, Bart->BigBirdPegasusDecoder
class BigBirdPegasusDecoderAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -1231,17 +1258,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attention are all mixed together in this module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -1262,8 +1297,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -1275,66 +1310,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
class BigBirdPegasusEncoderLayer(nn.Module):
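The refactored `forward` above no longer hard-codes the attention math: it looks up a callable keyed by `config._attn_implementation` and falls back to the eager Python path otherwise. The sketch below illustrates that dispatch pattern in isolation; the registry name and the toy eager function are illustrative stand-ins rather than the library's `ALL_ATTENTION_FUNCTIONS` object, and only PyTorch's built-in SDPA is registered here.

```python
import torch
import torch.nn.functional as F
from typing import Callable, Dict

def eager_attn(q, k, v, mask=None, scaling=None):
    # plain softmax(QK^T * scaling) @ V, mirroring the eager fallback in the diff
    scaling = q.size(-1) ** -0.5 if scaling is None else scaling
    weights = (q @ k.transpose(-2, -1)) * scaling
    if mask is not None:
        weights = weights + mask
    return weights.softmax(dim=-1) @ v

# illustrative registry standing in for ALL_ATTENTION_FUNCTIONS
ATTENTION_REGISTRY: Dict[str, Callable] = {"sdpa": F.scaled_dot_product_attention}

def pick_attention(attn_implementation: str) -> Callable:
    return eager_attn if attn_implementation == "eager" else ATTENTION_REGISTRY[attn_implementation]

q = k = v = torch.randn(1, 4, 6, 8)  # (batch, heads, seq_len, head_dim)
assert torch.allclose(pick_attention("eager")(q, k, v), pick_attention("sdpa")(q, k, v), atol=1e-5)
```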
@@ -1434,6 +1430,7 @@ def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None
dropout=config.attention_dropout,
is_decoder=True,
bias=config.use_bias,
+ config=config,
layer_idx=layer_idx,
)
self.dropout = config.dropout
@@ -1447,6 +1444,7 @@ def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None
dropout=config.attention_dropout,
is_decoder=True,
bias=config.use_bias,
+ config=config,
layer_idx=layer_idx,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@@ -1510,7 +1508,6 @@ def forward(
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -1602,24 +1599,33 @@ def dummy_inputs(self):
}
return dummy_inputs
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -1627,7 +1633,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -1661,7 +1667,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
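The `_unmask_unattended` step referenced above exists because a row that is masked at every key position (e.g. a pure left-padding row) leaves softmax with nothing to normalize, which is what trips up SDPA's memory-efficient path. A quick sketch of the numerical hazard (illustrative only, not the kernel's exact failure mode):

```python
import torch

# all positions masked with -inf: softmax has no finite entry left and returns NaN
print(torch.softmax(torch.full((1, 4), float("-inf")), dim=-1))      # tensor([[nan, nan, nan, nan]])

# with the finite `min_dtype` used for additive masks, softmax degrades to uniform weights instead
min_dtype = torch.finfo(torch.float32).min
print(torch.softmax(torch.full((1, 4), min_dtype), dim=-1))          # tensor([[0.25, 0.25, 0.25, 0.25]])
```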
@@ -1727,6 +1732,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
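In the default (eager) branch, `_prepare_4d_attention_mask` turns the 2D padding mask into an additive bias broadcast over heads and query positions. A hand-rolled sketch of that expansion, assuming the usual 1 = attend / 0 = pad convention (the real helper lives in `modeling_attn_mask_utils`):

```python
import torch

encoder_attention_mask = torch.tensor([[1, 1, 1, 0]])   # [bsz=1, src_len=4], last key is padding
tgt_len, dtype = 2, torch.float32

expanded = encoder_attention_mask[:, None, None, :].expand(-1, 1, tgt_len, -1).to(dtype)
inverted = 1.0 - expanded
additive = inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)
print(additive.shape)   # torch.Size([1, 1, 2, 4]); padded keys carry the most negative float
```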
class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel):
"""
@@ -2172,9 +2213,13 @@ def forward(
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
@@ -2207,28 +2252,26 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
- causal_mask = self._update_causal_mask(
+
+ attention_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
- position_ids = cache_position.unsqueeze(0)
- position_ids = self.embed_positions(
- (batch_size, seq_length), past_key_values_length, position_ids=position_ids
- )
- position_ids = position_ids.to(inputs_embeds.device)
- hidden_states = inputs_embeds + position_ids
+ positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
+ positions = positions.to(inputs_embeds.device)
+
+ hidden_states = inputs_embeds + positions
+
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
# decoder layers
@@ -2258,7 +2301,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- causal_mask,
+ attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -2271,7 +2314,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=causal_mask,
+ attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
@@ -2979,7 +3022,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- cache_position: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py
index d93b6f6ae2d8..0b2a0dc2749f 100755
--- a/src/transformers/models/biogpt/modeling_biogpt.py
+++ b/src/transformers/models/biogpt/modeling_biogpt.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/biogpt/modular_biogpt.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_biogpt.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
#
@@ -12,56 +18,46 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch BioGPT model."""
import math
-from typing import Optional, Tuple, Union
+from functools import partial
+from typing import Callable, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
-from torch import nn
+import torch.nn as nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import (
- AttentionMaskConverter,
-)
+from ...modeling_attn_mask_utils import AttentionMaskConverter
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutputWithPastAndCrossAttentions,
CausalLMOutputWithCrossAttentions,
SequenceClassifierOutputWithPast,
TokenClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- auto_docstring,
- is_torch_flex_attn_available,
- is_torchdynamo_compiling,
- logging,
-)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import LossKwargs, auto_docstring, is_torch_flex_attn_available, logging
from .configuration_biogpt import BioGptConfig
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
-# copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt
-# TODO @ArthurZucker bring copied from back
class BioGptLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
"""
def __init__(self, num_embeddings: int, embedding_dim: int):
- # BioGpt is set up so that if padding_idx is specified then offset the embedding ids by 2
+ # BIOGPT is set up so that if padding_idx is specified then offset the embedding ids by 2
# and adjust num_embeddings appropriately. Other models don't have this hack
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)
@@ -70,22 +66,19 @@ def forward(
self,
attention_mask: torch.LongTensor,
past_key_values_length: int = 0,
- position_ids: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
):
"""`input_ids_shape` is expected to be [bsz x seqlen]."""
- if position_ids is None:
- attention_mask = attention_mask.long()
-
- # create positions depending on attention_mask
- positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+ if position_ids is None:
+ position_ids = torch.cumsum(attention_mask, dim=1)
+ position_ids = (position_ids * attention_mask - 1).long()
# cut positions if `past_key_values_length` is > 0
- position_ids = positions[:, past_key_values_length:]
+ position_ids = position_ids[:, past_key_values_length:]
return super().forward(position_ids + self.offset)
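The cumulative-sum trick above derives position ids straight from the padding mask, so padded slots end up at -1 and real tokens count from 0 before `self.offset` (+2) shifts everything into the enlarged embedding table. A worked example on a left-padded sequence:

```python
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]])         # one left-padded sequence
position_ids = torch.cumsum(attention_mask, dim=1)
position_ids = (position_ids * attention_mask - 1).long()
print(position_ids)                                      # tensor([[-1, -1,  0,  1,  2]])
```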
-# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt
class BioGptScaledWordEmbedding(nn.Embedding):
"""
This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
@@ -99,7 +92,36 @@ def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BioGpt
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
class BioGptAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -151,148 +173,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed things for encoder, decoder, and encoder-decoder attn
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
-
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
-
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->BioGpt
-class BioGptSdpaAttention(BioGptAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "BioGptModel is using BioGptSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- cache_position=cache_position,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -313,8 +212,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -326,47 +225,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- causal_mask = None
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=causal_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-BIOGPT_ATTENTION_CLASSES = {
- "eager": BioGptAttention,
- "sdpa": BioGptSdpaAttention,
-}
+ return attn_output, attn_weights, past_key_value
class BioGptDecoderLayer(nn.Module):
@@ -374,12 +253,13 @@ def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.hidden_size
- self.self_attn = BIOGPT_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BioGptAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_probs_dropout_prob,
is_decoder=True,
is_causal=True,
+ config=config,
layer_idx=layer_idx,
)
self.dropout = config.hidden_dropout_prob
@@ -400,7 +280,9 @@ def forward(
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
+ position_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.Tensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -431,7 +313,9 @@ def forward(
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
+ position_ids=position_ids,
cache_position=cache_position,
+ **flash_attn_kwargs,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@@ -462,7 +346,10 @@ class BioGptPreTrainedModel(PreTrainedModel):
config_class = BioGptConfig
base_model_prefix = "biogpt"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
@@ -482,24 +369,33 @@ def _init_weights(self, module):
module.bias.data.zero_()
module.weight.data.fill_(1.0)
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -507,7 +403,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -541,7 +437,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -628,7 +523,6 @@ def __init__(self, config: BioGptConfig):
self.layer_norm = nn.LayerNorm(self.embed_dim)
self.gradient_checkpointing = False
- self._use_sdpa = config._attn_implementation == "sdpa"
# Initialize weights and apply final processing
self.post_init()
@@ -652,7 +546,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
- **kwargs, # NOOP kwargs, for now
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -663,18 +557,24 @@ def forward(
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
- raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
)
use_cache = False
@@ -696,7 +596,7 @@ def forward(
past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
)
- if attention_mask is None and not is_torchdynamo_compiling():
+ if attention_mask is None:
# required mask seq length can be calculated via length of past cache
mask_seq_length = past_key_values_length + seq_length
attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
@@ -706,27 +606,37 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
# embed positions
if position_ids is None:
- position_ids = cache_position.unsqueeze(0)
-
- position_ids = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
-
- hidden_states = inputs_embeds + position_ids
+ # position_ids = cache_position.unsqueeze(0)
+ position_ids = torch.cumsum(attention_mask, dim=1)
+ position_ids = (position_ids * attention_mask - 1).long()
+ # cut positions if `past_seen_tokens` is > 0
+ position_ids = position_ids[:, past_key_values_length:]
+
+ positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
+ hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = None
- next_decoder_cache = None
+ next_decoder_cache = () if use_cache else None
for idx, decoder_layer in enumerate(self.layers):
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
@@ -739,13 +649,14 @@ def forward(
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
- decoder_layer.__call__,
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
hidden_states,
causal_mask,
head_mask[idx] if head_mask is not None else None,
None,
output_attentions,
use_cache,
+ position_ids,
cache_position,
)
else:
@@ -756,7 +667,9 @@ def forward(
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
+ position_ids=position_ids,
cache_position=cache_position,
+ **flash_attn_kwargs,
)
hidden_states = layer_outputs[0]
@@ -792,6 +705,9 @@ def forward(
)
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
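`KwargsForCausalLM` merges the flash-attention and loss keyword groups into a single typed `**kwargs` bundle that `Unpack` exposes to static type checkers. The snippet below shows the pattern with a made-up TypedDict (the field names are hypothetical, and `typing_extensions` is assumed for Python < 3.12):

```python
from typing import Optional
from typing_extensions import TypedDict, Unpack

class ExampleAttnKwargs(TypedDict, total=False):     # hypothetical stand-in for FlashAttentionKwargs
    max_length_q: Optional[int]
    max_length_k: Optional[int]

def forward(hidden_states: str, **kwargs: Unpack[ExampleAttnKwargs]):
    # type checkers validate the keyword names; at runtime kwargs is a plain dict
    return hidden_states, kwargs

print(forward("tokens", max_length_q=128))           # ('tokens', {'max_length_q': 128})
```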
@auto_docstring(
custom_intro="""
BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
@@ -830,7 +746,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.Tensor] = None,
- **kwargs,
+ **kwargs: Unpack[KwargsForCausalLM],
) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
r"""
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -852,6 +768,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
+ **kwargs,
)
sequence_output = outputs[0]
@@ -916,9 +833,11 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, TokenClassifierOutput]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -935,9 +854,11 @@ def forward(
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
hidden_states = transformer_outputs[0]
@@ -1004,9 +925,11 @@ def forward(
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
) -> Union[Tuple, SequenceClassifierOutputWithPast]:
r"""
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -1023,9 +946,11 @@ def forward(
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
+ position_ids=position_ids,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
hidden_states = transformer_outputs[0]
logits = self.score(hidden_states)
diff --git a/src/transformers/models/biogpt/modular_biogpt.py b/src/transformers/models/biogpt/modular_biogpt.py
new file mode 100644
index 000000000000..4bd675be9216
--- /dev/null
+++ b/src/transformers/models/biogpt/modular_biogpt.py
@@ -0,0 +1,851 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Team and Microsoft Research AI4Science All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch BioGPT model."""
+
+import math
+from functools import partial
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+
+from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_outputs import (
+ BaseModelOutputWithPastAndCrossAttentions,
+ CausalLMOutputWithCrossAttentions,
+ SequenceClassifierOutputWithPast,
+ TokenClassifierOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import (
+ LossKwargs,
+ auto_docstring,
+ is_torch_flex_attn_available,
+ logger,
+)
+from ..bart.modeling_bart import (
+ BartAttention,
+ BartDecoderLayer,
+ BartScaledWordEmbedding,
+)
+from ..opt.modeling_opt import OPTLearnedPositionalEmbedding
+from .configuration_biogpt import BioGptConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
+
+
+class BioGptLearnedPositionalEmbedding(OPTLearnedPositionalEmbedding):
+ def forward(
+ self,
+ attention_mask: torch.LongTensor,
+ past_key_values_length: int = 0,
+ position_ids: Optional[torch.LongTensor] = None,
+ ):
+ """`input_ids_shape` is expected to be [bsz x seqlen]."""
+ super().forward(attention_mask, past_key_values_length, position_ids)
+
+
+class BioGptScaledWordEmbedding(BartScaledWordEmbedding):
+ pass
+
+
+class BioGptAttention(BartAttention):
+ pass
+
+
+class BioGptDecoderLayer(BartDecoderLayer):
+ def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None):
+ super().__init__(config)
+ self.embed_dim = config.hidden_size
+
+ self.self_attn = BioGptAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.num_attention_heads,
+ dropout=config.attention_probs_dropout_prob,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.hidden_dropout_prob
+ self.activation_fn = ACT2FN[config.hidden_act]
+
+ self.fc1 = nn.Linear(self.embed_dim, config.intermediate_size)
+ self.fc2 = nn.Linear(config.intermediate_size, self.embed_dim)
+
+ del self.encoder_attn
+ del self.encoder_attn_layer_norm
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ position_ids: Optional[torch.LongTensor] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ use_cache (`bool`, *optional*):
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
+ (see `past_key_values`).
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ residual = hidden_states
+
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Self Attention
+ hidden_states, self_attn_weights, past_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ **flash_attn_kwargs,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+ hidden_states = self.fc1(hidden_states)
+ hidden_states = self.activation_fn(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights,)
+
+ if use_cache:
+ outputs += (past_key_value,)
+
+ return outputs
+
+
+@auto_docstring
+class BioGptPreTrainedModel(PreTrainedModel):
+ config_class = BioGptConfig
+ base_model_prefix = "biogpt"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
+ _supports_cache_class = True
+ _supports_static_cache = True
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, nn.LayerNorm):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`torch.Tensor`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
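To make the shape bookkeeping concrete, the snippet below replays the core of the construction above for three new tokens appended to a cache of two, with padding handling omitted:

```python
import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
sequence_length, target_length = 3, 5            # 3 new tokens, 2 tokens already cached
cache_position = torch.arange(2, 5)              # absolute positions of the new tokens

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :]      # -> (1, 1, 3, 5); 0 = visible, min_dtype = masked
print((causal_mask == 0).squeeze())              # lower-triangular visibility shifted by the cache length
```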
+
+
+@auto_docstring
+class BioGptModel(BioGptPreTrainedModel):
+ def __init__(self, config: BioGptConfig):
+ super().__init__(config)
+ self.config = config
+ self.layerdrop = config.layerdrop
+ self.dropout = config.hidden_dropout_prob
+ self.embed_dim = config.hidden_size
+ self.padding_idx = config.pad_token_id
+ embed_scale = math.sqrt(config.hidden_size) if config.scale_embedding else 1.0
+
+ self.embed_tokens = BioGptScaledWordEmbedding(
+ config.vocab_size, self.embed_dim, self.padding_idx, embed_scale=embed_scale
+ )
+ self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim)
+
+ self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+ self.layer_norm = nn.LayerNorm(self.embed_dim)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_input_embeddings(self):
+ return self.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.embed_tokens = value
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **flash_attn_kwargs: Unpack[FlashAttentionKwargs],
+ ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if (input_ids is None) ^ (inputs_embeds is not None):
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # initialize past_key_values
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
+
+ batch_size, seq_length = inputs_embeds.size()[:-1]
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
+ )
+
+ if attention_mask is None:
+ # required mask seq length can be calculated via length of past cache
+ mask_seq_length = past_key_values_length + seq_length
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
+
+ self_attn_cache = (
+ past_key_values.self_attention_cache
+ if isinstance(past_key_values, EncoderDecoderCache)
+ else past_key_values
+ )
+
+ causal_mask = self._update_causal_mask(
+ attention_mask,
+ inputs_embeds,
+ cache_position,
+ self_attn_cache,
+ )
+
+ # embed positions
+ if position_ids is None:
+ # position_ids = cache_position.unsqueeze(0)
+ position_ids = torch.cumsum(attention_mask, dim=1)
+ position_ids = (position_ids * attention_mask - 1).long()
+ # cut positions if `past_seen_tokens` is > 0
+ position_ids = position_ids[:, past_key_values_length:]
+
+ positions = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids)
+ hidden_states = inputs_embeds + positions
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ all_cross_attentions = None
+ next_decoder_cache = () if use_cache else None
+
+ for idx, decoder_layer in enumerate(self.layers):
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop:
+ continue
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ partial(decoder_layer.__call__, **flash_attn_kwargs),
+ hidden_states,
+ causal_mask,
+ head_mask[idx] if head_mask is not None else None,
+ None,
+ output_attentions,
+ use_cache,
+ position_ids,
+ cache_position,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask=causal_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ past_key_value=past_key_values,
+ output_attentions=output_attentions,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ cache_position=cache_position,
+ **flash_attn_kwargs,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ hidden_states = self.layer_norm(hidden_states)
+
+ next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = past_key_values.to_legacy_cache()
+
+ if not return_dict:
+ return tuple(
+ v
+ for v in [hidden_states, next_cache, all_hidden_states, all_self_attns, all_cross_attentions]
+ if v is not None
+ )
+ return BaseModelOutputWithPastAndCrossAttentions(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ cross_attentions=all_cross_attentions,
+ )
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): ...
+
+
+@auto_docstring(
+ custom_intro="""
+ BioGPT Model with a `language modeling` head on top for CLM fine-tuning.
+ """
+)
+class BioGptForCausalLM(BioGptPreTrainedModel, GenerationMixin):
+ _tied_weights_keys = ["output_projection.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.biogpt = BioGptModel(config)
+ self.output_projection = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def get_output_embeddings(self):
+ return self.output_projection
+
+ def set_output_embeddings(self, new_embeddings):
+ self.output_projection = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ **kwargs: Unpack[KwargsForCausalLM],
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
+            `labels = input_ids`. Indices are selected in `[-100, 0, ..., config.vocab_size]`. All labels set to `-100`
+            are ignored (masked); the loss is only computed for labels in `[0, ..., config.vocab_size]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ outputs = self.biogpt(
+ input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ past_key_values=past_key_values,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ sequence_output = outputs[0]
+ prediction_scores = self.output_projection(sequence_output)
+
+ lm_loss = None
+ if labels is not None:
+ lm_loss = self.loss_function(
+ prediction_scores,
+ labels,
+ vocab_size=self.config.vocab_size,
+ **kwargs,
+ )
+
+ if not return_dict:
+ output = (prediction_scores,) + outputs[1:]
+ return ((lm_loss,) + output) if lm_loss is not None else output
+
+ return CausalLMOutputWithCrossAttentions(
+ loss=lm_loss,
+ logits=prediction_scores,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ cross_attentions=outputs.cross_attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
+ )
+ return reordered_past
+
+
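For context (again outside the diff), `_reorder_cache` above simply re-indexes every cached tensor along the batch dimension so the cache follows the surviving beams. A toy sketch with a single-layer legacy cache and made-up beam indices:

```python
import torch

# one layer of a legacy cache: (key, value), each shaped (batch, heads, seq_len, head_dim)
key = torch.arange(4.0).view(4, 1, 1, 1)   # 4 beams, identified by their value
value = key.clone()
past_key_values = ((key, value),)

# beam search keeps beams 2 and 0, each duplicated once
beam_idx = torch.tensor([2, 2, 0, 0])

reordered_past = tuple(
    tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
    for layer_past in past_key_values
)
print(reordered_past[0][0].flatten())  # tensor([2., 2., 0., 0.])
```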
+@auto_docstring
+class BioGptForTokenClassification(BioGptPreTrainedModel):
+ def __init__(self, config):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+
+ self.biogpt = BioGptModel(config)
+ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None:
+ classifier_dropout = config.classifier_dropout
+ else:
+ classifier_dropout = config.hidden_dropout_prob
+ self.dropout = nn.Dropout(classifier_dropout)
+ self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ token_type_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, TokenClassifierOutput]:
+ r"""
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.biogpt(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ hidden_states = transformer_outputs[0]
+ hidden_states = self.dropout(hidden_states)
+ logits = self.classifier(hidden_states)
+
+ loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ # Only keep active parts of the loss
+ if attention_mask is not None:
+ active_loss = attention_mask.view(-1) == 1
+ active_logits = logits.view(-1, self.num_labels)
+ active_labels = torch.where(
+ active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
+ )
+ loss = loss_fct(active_logits, active_labels)
+ else:
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+ if not return_dict:
+ output = (logits,) + transformer_outputs[2:]
+ return ((loss,) + output) if loss is not None else output
+
+ return TokenClassifierOutput(
+ loss=loss,
+ logits=logits,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
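To make the padding-aware loss above concrete (an illustrative sketch only, not part of the diff): positions where the attention mask is 0 are mapped to `CrossEntropyLoss.ignore_index`, so they contribute nothing to the token-classification loss.

```python
import torch
from torch.nn import CrossEntropyLoss

num_labels = 3
logits = torch.randn(2, 4, num_labels)                       # (batch, seq_len, num_labels)
labels = torch.tensor([[1, 2, 0, 0], [2, 1, 1, 0]])
attention_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]])  # row 0 ends in two pad tokens

loss_fct = CrossEntropyLoss()
active_loss = attention_mask.view(-1) == 1
active_logits = logits.view(-1, num_labels)
# padded positions are replaced by ignore_index (-100) and are skipped by the loss
active_labels = torch.where(
    active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
)
loss = loss_fct(active_logits, active_labels)
print(loss)
```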
+@auto_docstring(
+ custom_intro="""
+ The BioGpt Model transformer with a sequence classification head on top (linear layer).
+
+ [`BioGptForSequenceClassification`] uses the last token in order to do the classification, as other causal models
+ (e.g. GPT-2) do.
+
+ Since it does classification on the last token, it is required to know the position of the last token. If a
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
+ each row of the batch).
+ """
+)
+class BioGptForSequenceClassification(BioGptPreTrainedModel):
+ def __init__(self, config: BioGptConfig):
+ super().__init__(config)
+ self.num_labels = config.num_labels
+ self.biogpt = BioGptModel(config)
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ head_mask: Optional[torch.FloatTensor] = None,
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss). If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.biogpt(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ position_ids=position_ids,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if input_ids is not None:
+ batch_size, sequence_length = input_ids.shape[:2]
+ else:
+ batch_size, sequence_length = inputs_embeds.shape[:2]
+
+ if self.config.pad_token_id is None:
+ sequence_length = -1
+ else:
+ if input_ids is not None:
+ sequence_length = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+ else:
+ sequence_length = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds`."
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_length]
+
+ loss = None
+ if labels is not None:
+ if self.config.problem_type is None:
+ if self.num_labels == 1:
+ self.config.problem_type = "regression"
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+ self.config.problem_type = "single_label_classification"
+ else:
+ self.config.problem_type = "multi_label_classification"
+
+ if self.config.problem_type == "regression":
+ loss_fct = MSELoss()
+ if self.num_labels == 1:
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
+ else:
+ loss = loss_fct(pooled_logits, labels)
+ elif self.config.problem_type == "single_label_classification":
+ loss_fct = CrossEntropyLoss()
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
+ elif self.config.problem_type == "multi_label_classification":
+ loss_fct = BCEWithLogitsLoss()
+ loss = loss_fct(pooled_logits, labels)
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ def get_input_embeddings(self):
+ return self.biogpt.embed_tokens
+
+ def set_input_embeddings(self, value):
+ self.biogpt.embed_tokens = value
+
+
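And for the sequence-classification head above, the pooling step picks one logit row per example: the last non-padding token when a `pad_token_id` is set, otherwise simply the last position. A small sketch with hypothetical ids and `pad_token_id = 0`:

```python
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 6, 7, 0, 0],
                          [8, 9, 1, 2, 3]])
logits = torch.randn(2, 5, 4)                       # (batch, seq_len, num_labels)

# index of the last non-pad token per row: 2 for row 0, 4 for row 1
last_non_pad = torch.ne(input_ids, pad_token_id).sum(-1) - 1
pooled_logits = logits[torch.arange(input_ids.shape[0]), last_non_pad]

print(last_non_pad)         # tensor([2, 4])
print(pooled_logits.shape)  # torch.Size([2, 4])
```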
+__all__ = [
+ "BioGptForCausalLM",
+ "BioGptForTokenClassification",
+ "BioGptForSequenceClassification",
+ "BioGptModel",
+ "BioGptPreTrainedModel",
+]
diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py
index 8eb282ac6faf..da7282d388fb 100755
--- a/src/transformers/models/blenderbot/modeling_blenderbot.py
+++ b/src/transformers/models/blenderbot/modeling_blenderbot.py
@@ -18,7 +18,7 @@
import math
import os
import warnings
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -31,7 +31,9 @@
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -39,7 +41,8 @@
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -51,9 +54,7 @@
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -110,6 +111,37 @@ def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Blenderbot
class BlenderbotAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -162,17 +194,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # At the moment, encoder, decoder, and encoder-decoder attention are mixed in this module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -193,8 +233,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -206,69 +246,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-BLENDERBOT_ATTENTION_CLASSES = {"eager": BlenderbotAttention}
+ return attn_output, attn_weights, past_key_value
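
A quick editorial sanity check (not part of the patch): with no mask, no head mask, and no dropout, the eager path kept above computes the same attention as PyTorch's fused `scaled_dot_product_attention`, which is what the `ALL_ATTENTION_FUNCTIONS` dispatch swaps in for the non-eager implementations. The sketch below ignores the final transpose/reshape the module applies afterwards.

```python
import torch
import torch.nn.functional as F

bsz, heads, tgt_len, head_dim = 2, 4, 5, 8
q = torch.randn(bsz, heads, tgt_len, head_dim)
k = torch.randn(bsz, heads, tgt_len, head_dim)
v = torch.randn(bsz, heads, tgt_len, head_dim)
scaling = head_dim ** -0.5

# eager math, as in `eager_attention_forward` (no mask, no head mask, no dropout)
attn_weights = torch.matmul(q, k.transpose(2, 3)) * scaling
attn_weights = F.softmax(attn_weights, dim=-1)
eager_out = torch.matmul(attn_weights, v)

# fused operator on the same inputs, same default 1/sqrt(head_dim) scaling
sdpa_out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0)

print(torch.allclose(eager_out, sdpa_out, atol=1e-5))  # True, up to float tolerance
```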
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Blenderbot, MBART->BLENDERBOT
@@ -277,7 +275,7 @@ def __init__(self, config: BlenderbotConfig):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BlenderbotAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -346,7 +344,7 @@ def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BlenderbotAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -360,7 +358,7 @@ def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = BLENDERBOT_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = BlenderbotAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -428,7 +426,6 @@ def forward(
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -465,6 +462,10 @@ class BlenderbotPreTrainedModel(PreTrainedModel):
config_class = BlenderbotConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
@@ -493,24 +494,56 @@ def dummy_inputs(self):
}
return dummy_inputs
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+                # output_attentions=True and head_mask are not supported when using SDPA; fall back to
+                # the manual implementation, which requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+            # Other attention flavors support built-in causal masking (when `mask is None`),
+            # while flex attention needs an explicit block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -518,7 +551,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -552,7 +585,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -618,6 +650,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+            # output_attentions=True and cross_attn_head_mask are not supported when using SDPA; we fall back on
+            # the manual implementation, which requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
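
For readers skimming the mask helpers above: the eager/SDPA branches ultimately expand a 2D padding mask into an additive 4D bias. The following is only a hand-rolled sketch of the shape and value convention (`1 -> 0.0`, `0 -> dtype minimum`), not the code of `_prepare_4d_attention_mask` itself:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0]])   # (bsz, src_len); the last source token is padding
dtype = torch.float32
tgt_len = 2                                     # decoder query length, e.g. for cross-attention

# [bsz, src_len] -> [bsz, 1, tgt_len, src_len]: kept tokens get 0.0, masked ones the dtype minimum
inverted = 1.0 - attention_mask[:, None, None, :].expand(-1, 1, tgt_len, -1).to(dtype)
additive_mask = inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

print(additive_mask.shape)     # torch.Size([1, 1, 2, 4])
print(additive_mask[0, 0, 0])  # 0.0 for the three kept tokens, -3.4028e+38 for the padded one
```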
class BlenderbotEncoder(BlenderbotPreTrainedModel):
"""
@@ -730,10 +798,10 @@ def forward(
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -927,22 +995,28 @@ def forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
- ## retrieve input_ids and inputs_embeds
+ # retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
# initialize `past_key_values`
return_legacy_cache = False
@@ -972,20 +1046,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
position_ids = self.embed_positions(
diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
index 2f778d729395..2237907aa0e8 100755
--- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
+++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py
@@ -16,7 +16,7 @@
import copy
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -29,7 +29,9 @@
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -37,7 +39,8 @@
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -48,9 +51,7 @@
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -94,6 +95,37 @@ def forward(
return super().forward(position_ids)
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BlenderbotSmall
class BlenderbotSmallAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -146,17 +178,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # At the moment, encoder, decoder, and encoder-decoder attention are mixed in this module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -177,8 +217,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -190,66 +230,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
@@ -258,7 +259,7 @@ def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = Non
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BlenderbotSmallAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -324,19 +325,13 @@ def forward(
return outputs
-# TODO: Implement attention with SDPA for TimeSeriesTransformer.
-BLENDERBOT_SMALL_ATTENTION_CLASSES = {
- "eager": BlenderbotSmallAttention,
-}
-
-
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL
class BlenderbotSmallDecoderLayer(nn.Module):
def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = BlenderbotSmallAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -350,7 +345,7 @@ def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = Non
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = BLENDERBOT_SMALL_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = BlenderbotSmallAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -424,6 +419,7 @@ def forward(
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@@ -454,6 +450,10 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel):
config_class = BlenderbotSmallConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
@@ -482,24 +482,56 @@ def dummy_inputs(self):
}
return dummy_inputs
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+            # output_attentions=True and head_mask are not supported when using SDPA; fall back to
+            # the manual implementation, which requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+            # Other attention flavors support built-in causal masking (when `mask is None`),
+            # while flex attention needs an explicit block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -507,7 +539,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -541,7 +573,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -607,6 +638,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+            # output_attentions=True and cross_attn_head_mask are not supported when using SDPA; we fall back on
+            # the manual implementation, which requires a 4D attention mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel):
"""
@@ -718,10 +785,10 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -909,24 +976,28 @@ def forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
- inputs_embeds = inputs_embeds * self.embed_scale
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
# initialize `past_key_values`
return_legacy_cache = False
@@ -956,20 +1027,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
position_ids = self.embed_positions(
diff --git a/src/transformers/models/data2vec/modeling_data2vec_audio.py b/src/transformers/models/data2vec/modeling_data2vec_audio.py
index e1a822ea0377..d9046ea6e8c9 100755
--- a/src/transformers/models/data2vec/modeling_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modeling_data2vec_audio.py
@@ -4,9 +4,24 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_data2vec_audio.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import warnings
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -16,7 +31,8 @@
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -25,16 +41,14 @@
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, is_peft_available, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available
from .configuration_data2vec_audio import Data2VecAudioConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
-logger = logging.get_logger(__name__)
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
class Data2VecAudioConvLayer(nn.Module):
@@ -167,6 +181,36 @@ def forward(self, hidden_states):
return hidden_states, norm_hidden_states
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
class Data2VecAudioAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -201,9 +245,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -212,6 +253,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # At the moment, encoder, decoder, and encoder-decoder attention are mixed in this module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -219,267 +263,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-class Data2VecAudioFlashAttention2(Data2VecAudioAttention):
- """
- Data2VecAudio flash attention module. This module inherits from `Data2VecAudioAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class Data2VecAudioSdpaAttention(Data2VecAudioAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "Data2VecAudioModel is using Data2VecAudioSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -494,18 +287,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -517,39 +310,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
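
For readers following the refactor: the hunk above replaces the per-implementation attention subclasses with a single dispatch through a registry of attention callables, falling back to the eager path. Below is a minimal, self-contained sketch of that pattern; the registry name, the toy `run_attention` helper, and the tensor shapes are illustrative assumptions, not the library's actual `ALL_ATTENTION_FUNCTIONS` mapping.

```python
# Minimal sketch of the attention-dispatch pattern shown above (assumed names, not the real registry).
from typing import Callable, Dict, Optional

import torch
import torch.nn.functional as F


def eager_attention(module, q, k, v, mask: Optional[torch.Tensor], scaling: float, dropout: float = 0.0, **kwargs):
    # q, k, v: (bsz, num_heads, seq_len, head_dim)
    weights = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        weights = weights + mask
    weights = F.softmax(weights, dim=-1)
    weights = F.dropout(weights, p=dropout, training=getattr(module, "training", False))
    out = torch.matmul(weights, v).transpose(1, 2).contiguous()
    return out, weights


def sdpa_attention(module, q, k, v, mask, scaling, dropout: float = 0.0, **kwargs):
    # `scale` kwarg requires a reasonably recent PyTorch (>= 2.1)
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout, scale=scaling)
    return out.transpose(1, 2).contiguous(), None  # SDPA does not return the attention weights


ATTENTION_REGISTRY: Dict[str, Callable] = {"eager": eager_attention, "sdpa": sdpa_attention}


def run_attention(impl: str, module, q, k, v, mask, scaling, dropout=0.0):
    # Unknown implementation names fall back to the eager path.
    fn = ATTENTION_REGISTRY.get(impl, eager_attention)
    return fn(module, q, k, v, mask, scaling=scaling, dropout=dropout)


if __name__ == "__main__":
    bsz, heads, seq, dim = 2, 4, 8, 16
    q = torch.randn(bsz, heads, seq, dim)
    k, v = torch.randn_like(q), torch.randn_like(q)
    out_eager, _ = run_attention("eager", None, q, k, v, None, scaling=dim**-0.5)
    out_sdpa, _ = run_attention("sdpa", None, q, k, v, None, scaling=dim**-0.5)
    print(torch.allclose(out_eager, out_sdpa, atol=1e-5))  # expected: True
```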
class Data2VecAudioFeedForward(nn.Module):
@@ -576,21 +357,15 @@ def forward(self, hidden_states):
return hidden_states
-DATA2VEC_AUDIO_ATTENTION_CLASSES = {
- "eager": Data2VecAudioAttention,
- "sdpa": Data2VecAudioSdpaAttention,
- "flash_attention_2": Data2VecAudioFlashAttention2,
-}
-
-
class Data2VecAudioEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = DATA2VEC_AUDIO_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = Data2VecAudioAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
@@ -627,7 +402,6 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Data2VecAudioEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -644,16 +418,11 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -702,6 +471,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
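
A minimal sketch of what the eager/SDPA branch of `_update_full_mask` produces: a 2D padding mask of shape `(bsz, seq_len)` becomes an additive 4D mask of shape `(bsz, 1, seq_len, seq_len)` whose masked positions hold the most negative representable value, so they vanish after the softmax. The helper below is a simplified stand-in written from the observed behaviour of `_prepare_4d_attention_mask`, not the library function itself.

```python
# Simplified stand-in for the 2D -> 4D additive-mask expansion (illustrative only).
from typing import Optional

import torch


def expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    # attention_mask: (bsz, src_len) with 1 = keep, 0 = padding
    bsz, src_len = attention_mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # 0.0 where tokens are kept, the dtype's minimum where they are padded
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)


mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
additive = expand_padding_mask(mask, torch.float32)
print(additive.shape)     # torch.Size([2, 1, 4, 4])
print(additive[0, 0, 0])  # tensor([ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38])
```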
class Data2VecAudioAdapterLayer(nn.Module):
def __init__(self, config):
@@ -760,6 +551,8 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/data2vec/modular_data2vec_audio.py b/src/transformers/models/data2vec/modular_data2vec_audio.py
index 58934d2e86a7..73a42937bd8e 100644
--- a/src/transformers/models/data2vec/modular_data2vec_audio.py
+++ b/src/transformers/models/data2vec/modular_data2vec_audio.py
@@ -1,3 +1,19 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Data2VecText model."""
+
import math
import torch
@@ -124,6 +140,8 @@ class Data2VecAudioPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py
index d920e998f97f..eb366963a674 100755
--- a/src/transformers/models/hubert/modeling_hubert.py
+++ b/src/transformers/models/hubert/modeling_hubert.py
@@ -4,8 +4,23 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_hubert.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import warnings
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -15,15 +30,17 @@
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_hubert import HubertConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -224,6 +241,36 @@ def forward(self, hidden_states):
return hidden_states
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
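
The `head_mask` branch in the new `eager_attention_forward` multiplies the softmaxed weights by a per-head factor broadcast over the batch and sequence dimensions. A short illustrative check of that broadcast (the tensor shapes are example assumptions, not values from the model):

```python
# Illustrative check of the per-head mask broadcast used in eager_attention_forward.
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.rand(bsz, num_heads, tgt_len, src_len).softmax(dim=-1)

# Disable heads 1 and 3, keep heads 0 and 2 untouched.
head_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])
masked = attn_weights * head_mask.view(1, -1, 1, 1)  # broadcasts over (bsz, ., tgt_len, src_len)

print(masked[:, 1].abs().sum().item(), masked[:, 3].abs().sum().item())  # 0.0 0.0
print(torch.equal(masked[:, 0], attn_weights[:, 0]))                     # True
```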
class HubertAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -258,9 +305,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -269,6 +313,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # At the moment, encoder, decoder, and encoder-decoder attention kwargs are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -276,267 +323,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-class HubertFlashAttention2(HubertAttention):
- """
- Hubert flash attention module. This module inherits from `HubertAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class HubertSdpaAttention(HubertAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "HubertModel is using HubertSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -551,18 +347,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -574,39 +370,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
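
The projections are now reshaped with a plain `view(bsz, seq_len, -1, head_dim).transpose(1, 2)` instead of the removed `_shape` helper. A minimal check that both spellings produce the same `(bsz, num_heads, seq_len, head_dim)` layout (dimensions here are arbitrary example values):

```python
# Minimal check that the new view/transpose reshape matches the old `_shape` helper.
import torch

bsz, seq_len, num_heads, head_dim = 2, 6, 4, 8
states = torch.randn(bsz, seq_len, num_heads * head_dim)

# Old helper: view -> transpose -> contiguous
old = states.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()

# New inline form used above, with input_shape = (bsz, seq_len, -1, head_dim)
new = states.view(bsz, seq_len, -1, head_dim).transpose(1, 2)

print(old.shape, new.shape)   # both torch.Size([2, 4, 6, 8])
print(torch.equal(old, new))  # True (only contiguity differs, not the values)
```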
class HubertFeedForward(nn.Module):
@@ -633,21 +417,15 @@ def forward(self, hidden_states):
return hidden_states
-HUBERT_ATTENTION_CLASSES = {
- "eager": HubertAttention,
- "sdpa": HubertSdpaAttention,
- "flash_attention_2": HubertFlashAttention2,
-}
-
-
class HubertEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = HubertAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
@@ -684,7 +462,6 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([HubertEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -701,16 +478,11 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -759,6 +531,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
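
For the `flash_attention_2` branch, the mask stays 2D and is dropped entirely when nothing is padded (`attention_mask if 0 in attention_mask else None`), since a batch with no padding can skip masking and unpadding altogether. A small illustration of that rule:

```python
# Illustration of the flash-attention mask rule: keep a 2D mask only if it actually masks something.
from typing import Optional

import torch


def flash_attention_mask(attention_mask: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if attention_mask is not None and 0 in attention_mask:
        return attention_mask  # at least one padded position: keep the 2D mask
    return None  # no padding anywhere: skip masking and variable-length handling


padded = torch.tensor([[1, 1, 1, 0]])
full = torch.ones(1, 4, dtype=torch.long)

print(flash_attention_mask(padded))  # tensor([[1, 1, 1, 0]])
print(flash_attention_mask(full))    # None
```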
class HubertAttnAdapterLayer(nn.Module):
def __init__(self, config):
@@ -788,11 +582,12 @@ def forward(self, hidden_states: torch.FloatTensor):
class HubertEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = HUBERT_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = HubertAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -841,7 +636,6 @@ def __init__(self, config):
[HubertEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -855,19 +649,14 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- # make sure padded tokens are not attended to
+ # make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype)
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
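
The stable-layer-norm encoder now zeroes padded frames the same way as the post-norm encoder, using boolean indexing instead of a multiply. Both forms are equivalent for a 0/1 mask; a short check under example shapes:

```python
# Equivalence check: multiplying by the expanded mask vs. boolean-indexed zeroing of padded frames.
import torch

bsz, seq_len, hidden = 2, 4, 3
hidden_states = torch.randn(bsz, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])

expand = attention_mask.unsqueeze(-1).repeat(1, 1, hidden).bool()

multiplied = hidden_states * expand.to(hidden_states.dtype)

zeroed = hidden_states.clone()
zeroed[~expand] = 0  # in-place zeroing of padded positions, as in the encoder above

print(torch.allclose(multiplied, zeroed))  # True
```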
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -918,6 +707,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
@auto_docstring
class HubertPreTrainedModel(PreTrainedModel):
@@ -927,6 +738,8 @@ class HubertPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/hubert/modular_hubert.py b/src/transformers/models/hubert/modular_hubert.py
index b3e3d24cc0e8..75000c95cb38 100644
--- a/src/transformers/models/hubert/modular_hubert.py
+++ b/src/transformers/models/hubert/modular_hubert.py
@@ -1,3 +1,19 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Hubert model."""
+
from typing import Optional, Tuple, Union
import torch
@@ -115,6 +131,8 @@ class HubertPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py
index 8b728a19dff7..330bc620bc03 100644
--- a/src/transformers/models/informer/modeling_informer.py
+++ b/src/transformers/models/informer/modeling_informer.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/informer/modular_informer.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_informer.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved.
#
@@ -12,19 +18,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch Informer model."""
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
from torch import nn
from ...activations import ACT2FN
+from ...cache_utils import Cache, EncoderDecoderCache
from ...modeling_attn_mask_utils import (
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -32,19 +41,20 @@
Seq2SeqTSModelOutput,
Seq2SeqTSPredictionOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
-from ...utils import (
- auto_docstring,
- logging,
-)
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_informer import InformerConfig
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
logger = logging.get_logger(__name__)
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesFeatureEmbedder with TimeSeries->Informer
class InformerFeatureEmbedder(nn.Module):
"""
Embed a sequence of categorical features.
@@ -79,7 +89,6 @@ def forward(self, features: torch.Tensor) -> torch.Tensor:
)
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesStdScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer
class InformerStdScaler(nn.Module):
"""
Standardize features by calculating the mean and scaling along the first dimension, and then normalizes it by
@@ -115,7 +124,6 @@ def forward(
return (data - loc) / scale, loc, scale
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesMeanScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer
class InformerMeanScaler(nn.Module):
"""
Computes a scaling factor as the weighted average absolute value along the first dimension, and scales the data
@@ -170,7 +178,6 @@ def forward(
return scaled_data, torch.zeros_like(scale), scale
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesNOPScaler with TimeSeriesTransformer->Informer,TimeSeries->Informer
class InformerNOPScaler(nn.Module):
"""
Assigns a scaling factor equal to 1 along the first dimension, and therefore applies no scaling to the input data.
@@ -198,40 +205,6 @@ def forward(
return data, loc, scale
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.weighted_average
-def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
- """
- Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
- meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
-
- Args:
- input_tensor (`torch.FloatTensor`):
- Input tensor, of which the average must be computed.
- weights (`torch.FloatTensor`, *optional*):
- Weights tensor, of the same shape as `input_tensor`.
- dim (`int`, *optional*):
- The dim along which to average `input_tensor`.
-
- Returns:
- `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
- """
- if weights is not None:
- weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
- sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
- return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
- else:
- return input_tensor.mean(dim=dim)
-
-
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.nll
-def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
- """
- Computes the negative log likelihood loss from input distribution with respect to target.
- """
- return -input.log_prob(target)
-
-
-# Copied from transformers.models.marian.modeling_marian.MarianSinusoidalPositionalEmbedding with Marian->Informer
class InformerSinusoidalPositionalEmbedding(nn.Embedding):
"""This module produces sinusoidal positional embeddings of any length."""
@@ -266,7 +239,6 @@ def forward(
return super().forward(position_ids)
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Info
class InformerValueEmbedding(nn.Module):
def __init__(self, feature_size, d_model):
super().__init__()
@@ -276,7 +248,156 @@ def forward(self, x):
return self.value_projection(x)
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Informer
+@auto_docstring
+class InformerPreTrainedModel(PreTrainedModel):
+ config_class = InformerConfig
+ base_model_prefix = "model"
+ main_input_name = "past_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, InformerSinusoidalPositionalEmbedding):
+ module._init_weight()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
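
The new `flex_attention` branch builds a block mask from the padding mask via `make_flex_block_causal_mask(..., is_causal=False)`. The sketch below only shows the underlying PyTorch primitive, `torch.nn.attention.flex_attention.create_block_mask` (available in recent PyTorch releases, roughly 2.5+), with a padding-style `mask_mod`; it is not the Transformers helper itself, and the shapes are arbitrary examples.

```python
# Hedged sketch of a padding-only (non-causal) block mask with PyTorch's flex attention primitives.
import torch
from torch.nn.attention.flex_attention import create_block_mask

bsz, seq = 1, 128
padding = torch.ones(bsz, seq, dtype=torch.bool)
padding[:, 96:] = False  # last 32 positions are padding


def mask_mod(b, h, q_idx, kv_idx):
    # Non-causal: any query may attend to any key that is not padding.
    return padding[b, kv_idx]


block_mask = create_block_mask(mask_mod, B=bsz, H=None, Q_LEN=seq, KV_LEN=seq, device="cpu")
print(block_mask)  # the resulting BlockMask would be passed to flex_attention(q, k, v, block_mask=...)
```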
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
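
The eager branch of `_update_causal_mask` ultimately relies on `_prepare_4d_causal_attention_mask`, which combines the lower-triangular causal constraint with the padding mask into a single additive 4D tensor. A simplified, hedged reconstruction of that combination (ignoring the `past_key_values_length` handling):

```python
# Simplified reconstruction of a combined causal + padding additive mask (illustrative only).
import torch


def causal_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    bsz, seq_len = attention_mask.shape
    min_value = torch.finfo(dtype).min

    # Causal part: large negative above the diagonal, 0 on and below it.
    causal = torch.triu(torch.full((seq_len, seq_len), min_value, dtype=dtype), diagonal=1)
    causal = causal[None, None, :, :].expand(bsz, 1, seq_len, seq_len)

    # Padding part: additionally block every padded key column.
    return causal.masked_fill(attention_mask[:, None, None, :] == 0, min_value)


mask = torch.tensor([[1, 1, 1, 0]])
full = causal_padding_mask(mask, torch.float32)
print(full.shape)               # torch.Size([1, 1, 4, 4])
print((full[0, 0] == 0).int())  # lower-triangular visibility, with the padded last column blocked everywhere
```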
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
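
`_update_cross_attn_mask` differs from the self-attention case only in that the expanded mask is rectangular: the key length comes from the encoder, while `tgt_len` is the decoder's query length. A short shape-only sketch under those assumptions:

```python
# Shape sketch for the cross-attention padding mask: (bsz, enc_len) -> (bsz, 1, dec_len, enc_len).
import torch

bsz, enc_len, dec_len = 2, 7, 3
encoder_attention_mask = torch.ones(bsz, enc_len)
encoder_attention_mask[:, 5:] = 0  # last two encoder positions are padding

expanded = encoder_attention_mask[:, None, None, :].expand(bsz, 1, dec_len, enc_len)
additive = (1.0 - expanded) * torch.finfo(torch.float32).min

print(additive.shape)                         # torch.Size([2, 1, 3, 7])
print((additive[0, 0, 0] == 0).sum().item())  # 5 visible encoder positions per decoder step
```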
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
class InformerAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -289,6 +410,7 @@ def __init__(
bias: bool = True,
is_causal: bool = False,
config: Optional[InformerConfig] = None,
+ layer_idx: Optional[int] = None,
):
super().__init__()
self.embed_dim = embed_dim
@@ -305,23 +427,31 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
+ self.layer_idx = layer_idx
+ if layer_idx is None and self.is_decoder:
+ logger.warning_once(
+ f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
+ "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
+ "when creating this class."
+ )
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # At the moment, encoder, decoder, and encoder-decoder attention kwargs are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -329,110 +459,69 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ if past_key_value is not None:
+ if isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
+ value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
)
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
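The rewritten forward above indexes the cache through `self.layer_idx` and distinguishes self- from cross-attention via `EncoderDecoderCache`. A minimal sketch of how such a cache object behaves, assuming the `DynamicCache`/`EncoderDecoderCache` utilities from `transformers.cache_utils` (all values are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# first decoding step for layer 0: 1 new token, 4 heads, head_dim 8
key_states = torch.randn(1, 4, 1, 8)
value_states = torch.randn(1, 4, 1, 8)
cache_position = torch.arange(1)
past_key_values.self_attention_cache.update(
    key_states, value_states, layer_idx=0, cache_kwargs={"cache_position": cache_position}
)

print(past_key_values.get_seq_length())          # 1
print(past_key_values.is_updated.get(0, False))  # False until the cross-attention cache is filled
```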
class InformerProbSparseAttention(nn.Module):
@@ -448,6 +537,7 @@ def __init__(
is_decoder: bool = False,
sampling_factor: int = 5,
bias: bool = True,
+ layer_idx: Optional[int] = None,
):
super().__init__()
self.factor = sampling_factor
@@ -463,6 +553,7 @@ def __init__(
)
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
+ self.layer_idx = layer_idx
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
@@ -480,6 +571,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -488,45 +580,43 @@ def forward(
is_cross_attention = key_value_states is not None
bsz, tgt_len, _ = hidden_states.size()
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
+
+ if past_key_value is not None:
+ if isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
# reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
+ value_states = curr_past_key_value.value_cache[self.layer_idx]
else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
proj_shape = (bsz * self.num_heads, -1, self.head_dim)
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
@@ -681,6 +771,14 @@ class InformerEncoderLayer(nn.Module):
def __init__(self, config: InformerConfig):
super().__init__()
self.embed_dim = config.d_model
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+ self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
+ self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
if config.attention_type == "prob":
self.self_attn = InformerProbSparseAttention(
embed_dim=self.embed_dim,
@@ -693,14 +791,8 @@ def __init__(self, config: InformerConfig):
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
+ config=config,
)
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
- self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim)
- self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
@@ -754,9 +846,26 @@ def forward(
class InformerDecoderLayer(nn.Module):
- def __init__(self, config: InformerConfig):
+ def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = InformerAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
if config.attention_type == "prob":
self.self_attn = InformerProbSparseAttention(
@@ -765,6 +874,7 @@ def __init__(self, config: InformerConfig):
dropout=config.attention_dropout,
sampling_factor=config.sampling_factor,
is_decoder=True,
+ layer_idx=layer_idx,
)
else:
self.self_attn = InformerAttention(
@@ -772,22 +882,9 @@ def __init__(self, config: InformerConfig):
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
+ config=config,
+ layer_idx=layer_idx,
)
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
-
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = InformerAttention(
- self.embed_dim,
- config.decoder_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=True,
- )
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
def forward(
self,
@@ -797,9 +894,10 @@ def forward(
encoder_attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
@@ -818,47 +916,43 @@ def forward(
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
"""
residual = hidden_states
# Self Attention
- # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
- self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
- # add present self-attn cache to positions 1,2 of present_key_value tuple
- hidden_states, self_attn_weights, present_key_value = self.self_attn(
+ hidden_states, self_attn_weights, past_key_value = self.self_attn(
hidden_states=hidden_states,
- past_key_value=self_attn_past_key_value,
+ past_key_value=past_key_value,
attention_mask=attention_mask,
layer_head_mask=layer_head_mask,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.self_attn_layer_norm(hidden_states)
# Cross-Attention Block
- cross_attn_present_key_value = None
cross_attn_weights = None
if encoder_hidden_states is not None:
residual = hidden_states
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
- cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None
- hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn(
+ hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
layer_head_mask=cross_attn_layer_head_mask,
- past_key_value=cross_attn_past_key_value,
+ past_key_value=past_key_value,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
- # add cross-attn to positions 3,4 of present_key_value tuple
- present_key_value = present_key_value + cross_attn_present_key_value
-
# Fully Connected
residual = hidden_states
hidden_states = self.activation_fn(self.fc1(hidden_states))
@@ -874,36 +968,15 @@ def forward(
outputs += (self_attn_weights, cross_attn_weights)
if use_cache:
- outputs += (present_key_value,)
+ outputs += (past_key_value,)
return outputs
-@auto_docstring
-class InformerPreTrainedModel(PreTrainedModel):
- config_class = InformerConfig
- base_model_prefix = "model"
- main_input_name = "past_values"
- supports_gradient_checkpointing = True
-
- def _init_weights(self, module):
- std = self.config.init_std
- if isinstance(module, (nn.Linear, nn.Conv1d)):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, InformerSinusoidalPositionalEmbedding):
- module._init_weight()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
-
-
class InformerEncoder(InformerPreTrainedModel):
"""
- Informer encoder consisting of *config.encoder_layers* self attention layers with distillation layers. Each
- attention layer is an [`InformerEncoderLayer`].
+ Transformer encoder consisting of *config.encoder_layers* self-attention layers. Each layer is an
+ [`InformerEncoderLayer`].
Args:
config: InformerConfig
@@ -914,7 +987,6 @@ def __init__(self, config: InformerConfig):
self.dropout = config.dropout
self.layerdrop = config.encoder_layerdrop
- self.gradient_checkpointing = False
if config.prediction_length is None:
raise ValueError("The `prediction_length` config needs to be specified.")
@@ -924,6 +996,7 @@ def __init__(self, config: InformerConfig):
)
self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
+ self.gradient_checkpointing = False
if config.distil:
self.conv_layers = nn.ModuleList(
@@ -932,7 +1005,6 @@ def __init__(self, config: InformerConfig):
self.conv_layers.append(None)
else:
self.conv_layers = [None] * config.encoder_layers
-
# Initialize weights and apply final processing
self.post_init()
@@ -1053,7 +1125,7 @@ def forward(
class InformerDecoder(InformerPreTrainedModel):
"""
- Informer decoder consisting of *config.decoder_layers* layers. Each layer is a
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is an
[`InformerDecoderLayer`]
Args:
@@ -1071,7 +1143,7 @@ def __init__(self, config: InformerConfig):
self.embed_positions = InformerSinusoidalPositionalEmbedding(
config.context_length + config.prediction_length, config.d_model
)
- self.layers = nn.ModuleList([InformerDecoderLayer(config) for _ in range(config.decoder_layers)])
+ self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
self.layernorm_embedding = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@@ -1091,6 +1163,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]:
r"""
Args:
@@ -1148,6 +1221,9 @@ def forward(
for more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
@@ -1157,20 +1233,35 @@ def forward(
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
input_shape = inputs_embeds.size()[:-1]
+ # initialize `past_key_values`
+ return_legacy_cache = False
+ if use_cache and not isinstance(past_key_values, Cache):
+ return_legacy_cache = True
+ logger.warning_once(
+ "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
+ "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
+ "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
+ )
+ past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
- # past_key_values_length
- past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
+ past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
+ )
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
hidden_states = self.value_embedding(inputs_embeds)
embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length)
@@ -1188,7 +1279,7 @@ def forward(
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None
- next_decoder_cache = () if use_cache else None
+ next_decoder_cache = None
# check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired
for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]):
@@ -1208,8 +1299,6 @@ def forward(
if dropout_probability < self.layerdrop:
continue
- past_key_value = past_key_values[idx] if past_key_values is not None else None
-
if self.gradient_checkpointing and self.training:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
@@ -1222,6 +1311,7 @@ def forward(
None,
output_attentions,
use_cache,
+ cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -1233,14 +1323,15 @@ def forward(
cross_attn_layer_head_mask=(
cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None
),
- past_key_value=past_key_value,
+ past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
+ cache_position=cache_position,
)
hidden_states = layer_outputs[0]
if use_cache:
- next_decoder_cache += (layer_outputs[3 if output_attentions else 1],)
+ next_decoder_cache = layer_outputs[3 if output_attentions else 1]
if output_attentions:
all_self_attns += (layer_outputs[1],)
@@ -1253,6 +1344,9 @@ def forward(
all_hidden_states += (hidden_states,)
next_cache = next_decoder_cache if use_cache else None
+ if return_legacy_cache:
+ next_cache = past_key_values.to_legacy_cache()
+
if not return_dict:
return tuple(
v
@@ -1269,7 +1363,6 @@ def forward(
@auto_docstring
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerModel with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer,TimeSeries->Informer
class InformerModel(InformerPreTrainedModel):
def __init__(self, config: InformerConfig):
super().__init__(config)
@@ -1408,7 +1501,6 @@ def get_encoder(self):
def get_decoder(self):
return self.decoder
- # Ignore copy
@auto_docstring
def forward(
self,
@@ -1429,6 +1521,7 @@ def forward(
output_attentions: Optional[bool] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Seq2SeqTSModelOutput, Tuple]:
r"""
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -1586,7 +1679,16 @@ def forward(
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
- dec_input = transformer_inputs[:, self.config.context_length :, ...]
+ # Avoid empty tensors and instead create a zeroes tensor which
+ # will be treated the same in torch, i.e. matmul with empty == all 0s
+ if self.config.context_length >= transformer_inputs.shape[1]:
+ bsz, _, dim = transformer_inputs.shape
+ dec_input = torch.zeros(
+ size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype
+ )
+ else:
+ dec_input = transformer_inputs[:, self.config.context_length :, ...]
+
decoder_outputs = self.decoder(
inputs_embeds=dec_input,
attention_mask=decoder_attention_mask,
@@ -1598,6 +1700,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
+ cache_position=cache_position,
)
if not return_dict:
@@ -1618,11 +1721,42 @@ def forward(
)
+def weighted_average(input_tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim=None) -> torch.Tensor:
+ """
+ Computes the weighted average of a given tensor across a given `dim`, masking values associated with weight zero,
+ meaning instead of `nan * 0 = nan` you will get `0 * 0 = 0`.
+
+ Args:
+ input_tensor (`torch.FloatTensor`):
+ Input tensor, of which the average must be computed.
+ weights (`torch.FloatTensor`, *optional*):
+ Weights tensor, of the same shape as `input_tensor`.
+ dim (`int`, *optional*):
+ The dim along which to average `input_tensor`.
+
+ Returns:
+ `torch.FloatTensor`: The tensor with values averaged along the specified `dim`.
+ """
+ if weights is not None:
+ weighted_tensor = torch.where(weights != 0, input_tensor * weights, torch.zeros_like(input_tensor))
+ sum_weights = torch.clamp(weights.sum(dim=dim) if dim else weights.sum(), min=1.0)
+ return (weighted_tensor.sum(dim=dim) if dim else weighted_tensor.sum()) / sum_weights
+ else:
+ return input_tensor.mean(dim=dim)
+
+
+def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the negative log likelihood loss from input distribution with respect to target.
+ """
+ return -input.log_prob(target)
+
+
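As a quick sanity check of the two helpers above (assuming `weighted_average` and `nll` as defined here are in scope), a small worked example:

```python
import torch

x = torch.tensor([[1.0, 2.0, 3.0]])
w = torch.tensor([[1.0, 0.0, 1.0]])
# zero-weight entries are masked out: (1*1 + 3*1) / (1 + 1) = 2.0
print(weighted_average(x, weights=w, dim=1))  # tensor([2.])

dist = torch.distributions.Normal(loc=torch.zeros(1), scale=torch.ones(1))
print(nll(dist, torch.zeros(1)))  # tensor([0.9189]) == 0.5 * log(2 * pi)
```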
@auto_docstring
-# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerForPrediction with TimeSeriesTransformer->Informer,TIME_SERIES_TRANSFORMER->INFORMER,time-series-transformer->informer
class InformerForPrediction(InformerPreTrainedModel):
def __init__(self, config: InformerConfig):
super().__init__(config)
+
self.model = InformerModel(config)
if config.distribution_output == "student_t":
self.distribution_output = StudentTOutput(dim=config.input_size)
@@ -1660,7 +1794,6 @@ def output_distribution(self, params, loc=None, scale=None, trailing_n=None) ->
sliced_params = [p[:, -trailing_n:] for p in params]
return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale)
- # Ignore copy
@auto_docstring
def forward(
self,
@@ -1682,6 +1815,7 @@ def forward(
output_attentions: Optional[bool] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Seq2SeqTSModelOutput, Tuple]:
r"""
past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
@@ -1853,6 +1987,7 @@ def forward(
output_attentions=output_attentions,
use_cache=use_cache,
return_dict=return_dict,
+ cache_position=cache_position,
)
prediction_loss = None
diff --git a/src/transformers/models/informer/modular_informer.py b/src/transformers/models/informer/modular_informer.py
new file mode 100644
index 000000000000..15bcb8d38a83
--- /dev/null
+++ b/src/transformers/models/informer/modular_informer.py
@@ -0,0 +1,997 @@
+# coding=utf-8
+# Copyright 2023 Amazon and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch Informer model."""
+
+from typing import Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch import nn
+
+from ...cache_utils import EncoderDecoderCache
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_outputs import (
+ BaseModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
+from ...utils import (
+ auto_docstring,
+ is_torch_flex_attn_available,
+)
+from ..bart.modeling_bart import BartAttention
+from ..time_series_transformer.modeling_time_series_transformer import (
+ TimeSeriesFeatureEmbedder,
+ TimeSeriesMeanScaler,
+ TimeSeriesNOPScaler,
+ TimeSeriesSinusoidalPositionalEmbedding,
+ TimeSeriesStdScaler,
+ TimeSeriesTransformerDecoder,
+ TimeSeriesTransformerDecoderLayer,
+ TimeSeriesTransformerEncoder,
+ TimeSeriesTransformerEncoderLayer,
+ TimeSeriesTransformerForPrediction,
+ TimeSeriesTransformerModel,
+ TimeSeriesValueEmbedding,
+)
+from .configuration_informer import InformerConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
+def nll(input: torch.distributions.Distribution, target: torch.Tensor) -> torch.Tensor:
+ """
+ Computes the negative log likelihood loss from input distribution with respect to target.
+ """
+ return -input.log_prob(target)
+
+
+class InformerFeatureEmbedder(TimeSeriesFeatureEmbedder):
+ pass
+
+
+class InformerStdScaler(TimeSeriesStdScaler):
+ pass
+
+
+class InformerMeanScaler(TimeSeriesMeanScaler):
+ pass
+
+
+class InformerNOPScaler(TimeSeriesNOPScaler):
+ pass
+
+
+class InformerSinusoidalPositionalEmbedding(TimeSeriesSinusoidalPositionalEmbedding):
+ pass
+
+
+class InformerValueEmbedding(TimeSeriesValueEmbedding):
+ pass
+
+
+@auto_docstring
+class InformerPreTrainedModel(PreTrainedModel):
+ config_class = InformerConfig
+ base_model_prefix = "model"
+ main_input_name = "past_values"
+ supports_gradient_checkpointing = True
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, (nn.Linear, nn.Conv1d)):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, InformerSinusoidalPositionalEmbedding):
+ module._init_weight()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
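The mask helpers above turn 2D padding masks into the 4D additive bias the attention layers expect. A small sketch of the shape convention, assuming `_prepare_4d_attention_mask` from `transformers.modeling_attn_mask_utils` (the same helper the non-SDPA branches fall back to):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

mask = torch.tensor([[1, 1, 0]])  # (bsz=1, src_len=3); the last position is padding
expanded = _prepare_4d_attention_mask(mask, dtype=torch.float32, tgt_len=2)

print(expanded.shape)     # torch.Size([1, 1, 2, 3]) -> (bsz, 1, tgt_len, src_len)
print(expanded[0, 0, 0])  # tensor([0., 0., -3.4028e+38]); padded keys get a large negative bias
```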
+
+class InformerAttention(BartAttention):
+ pass
+
+
+class InformerProbSparseAttention(nn.Module):
+ """Probabilistic Attention mechanism to select the "active"
+ queries rather than the "lazy" queries and provides a sparse Transformer thus mitigating the quadratic compute and
+ memory requirements of vanilla attention"""
+
+ def __init__(
+ self,
+ embed_dim: int,
+ num_heads: int,
+ dropout: float = 0.0,
+ is_decoder: bool = False,
+ sampling_factor: int = 5,
+ bias: bool = True,
+ layer_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ self.factor = sampling_factor
+ self.embed_dim = embed_dim
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.head_dim = embed_dim // num_heads
+
+ if (self.head_dim * num_heads) != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
+ f" and `num_heads`: {num_heads})."
+ )
+ self.scaling = self.head_dim**-0.5
+ self.is_decoder = is_decoder
+ self.layer_idx = layer_idx
+
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ key_value_states: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ output_attentions: bool = False,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+ """Input shape: Batch x Time x Channel"""
+
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
+
+ bsz, tgt_len, _ = hidden_states.size()
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
+
+ # get query proj
+ query_states = self.q_proj(hidden_states) * self.scaling
+
+ if past_key_value is not None:
+ if isinstance(past_key_value, EncoderDecoderCache):
+ is_updated = past_key_value.is_updated.get(self.layer_idx)
+ if is_cross_attention:
+ # after the first generated id, we can subsequently re-use all key/value_states from cache
+ curr_past_key_value = past_key_value.cross_attention_cache
+ else:
+ curr_past_key_value = past_key_value.self_attention_cache
+ else:
+ curr_past_key_value = past_key_value
+
+ current_states = key_value_states if is_cross_attention else hidden_states
+ if is_cross_attention and past_key_value is not None and is_updated:
+ # reuse k,v, cross_attentions
+ key_states = curr_past_key_value.key_cache[self.layer_idx]
+ value_states = curr_past_key_value.value_cache[self.layer_idx]
+ else:
+ key_states = self.k_proj(current_states)
+ value_states = self.v_proj(current_states)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
+
+ if past_key_value is not None:
+ # save all key/value_states to cache to be re-used for fast auto-regressive generation
+ cache_position = cache_position if not is_cross_attention else None
+ key_states, value_states = curr_past_key_value.update(
+ key_states, value_states, self.layer_idx, {"cache_position": cache_position}
+ )
+ # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
+ if is_cross_attention:
+ past_key_value.is_updated[self.layer_idx] = True
+
+ proj_shape = (bsz * self.num_heads, -1, self.head_dim)
+ query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
+ key_states = key_states.reshape(*proj_shape)
+ value_states = value_states.reshape(*proj_shape)
+
+ key_states_time_length = key_states.size(1) # L_K
+ log_key_states_time_length = np.ceil(np.log1p(key_states_time_length)).astype("int").item() # log_L_K
+
+ query_states_time_length = query_states.size(1) # L_Q
+ log_query_states_time_length = np.ceil(np.log1p(query_states_time_length)).astype("int").item() # log_L_Q
+
+ u_part = min(self.factor * query_states_time_length * log_key_states_time_length, key_states_time_length)
+ u = min(self.factor * log_query_states_time_length, query_states_time_length)
+
+ if key_states_time_length > 0:
+ index_sample = torch.randint(0, key_states_time_length, (u_part,))
+ k_sample = key_states[:, index_sample, :]
+ else:
+ k_sample = key_states
+
+ queries_keys_sample = torch.bmm(query_states, k_sample.transpose(1, 2)) # Q_K_sampled
+
+ # find the Top_k query with sparsity measurement
+ if u > 0:
+ sparsity_measurement = queries_keys_sample.max(dim=-1)[0] - torch.div(
+ queries_keys_sample.sum(dim=-1), key_states_time_length
+ ) # M
+ top_u_sparsity_measurement = sparsity_measurement.topk(u, sorted=False)[1] # M_top
+
+ # calculate q_reduce: query_states[:, top_u_sparsity_measurement]
+ dim_for_slice = torch.arange(query_states.size(0)).unsqueeze(-1)
+ q_reduce = query_states[dim_for_slice, top_u_sparsity_measurement]
+ else:
+ q_reduce = query_states
+ top_u_sparsity_measurement = None
+
+ # Use q_reduce to calculate attention weights
+ attn_weights = torch.bmm(q_reduce, key_states.transpose(1, 2))
+
+ src_len = key_states.size(1)
+ if attn_weights.size() != (bsz * self.num_heads, u, src_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz * self.num_heads, u, src_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, tgt_len, src_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
+ )
+ prob_mask = attention_mask.expand(bsz, self.num_heads, tgt_len, src_len).reshape(
+ bsz * self.num_heads, tgt_len, src_len
+ )
+
+ if top_u_sparsity_measurement is not None:
+ dim_for_slice = torch.arange(prob_mask.size(0)).unsqueeze(-1)
+ prob_mask = prob_mask[dim_for_slice, top_u_sparsity_measurement, :]
+
+ attn_weights = attn_weights.view(bsz, self.num_heads, u, src_len) + prob_mask.view(
+ bsz, self.num_heads, u, src_len
+ )
+ attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len)
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if layer_head_mask is not None:
+ if layer_head_mask.size() != (self.num_heads,):
+ raise ValueError(
+ f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
+ f" {layer_head_mask.size()}"
+ )
+ attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, u, src_len)
+ attn_weights = attn_weights.view(bsz * self.num_heads, u, src_len)
+
+ if output_attentions:
+ # this operation is a bit awkward, but it's required to
+ # make sure that attn_weights keeps its gradient.
+ # In order to do so, attn_weights have to be reshaped
+ # twice and have to be reused in the following
+ attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, u, src_len)
+ attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, u, src_len)
+ else:
+ attn_weights_reshaped = None
+
+ attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
+ attn_output = torch.bmm(attn_probs, value_states)
+
+ # calculate context for updating the attn_output, based on:
+ # https://github.com/zhouhaoyi/Informer2020/blob/ac59c7447135473fb2aafeafe94395f884d5c7a5/models/attn.py#L74
+ if self.is_decoder:
+ # cast to float32 before operation to avoid overflow
+ context = value_states.cumsum(dim=-2, dtype=torch.float32).to(value_states.dtype)
+ else:
+ v_mean_dim_time = value_states.mean(dim=-2)
+ context = (
+ v_mean_dim_time.unsqueeze(dim=1)
+ .expand(bsz * self.num_heads, query_states_time_length, v_mean_dim_time.size(-1))
+ .clone()
+ )
+
+ if top_u_sparsity_measurement is not None:
+ # update context: copy the attention output to the context at top_u_sparsity_measurement index
+ dim_for_slice = torch.arange(context.size(0)).unsqueeze(-1)
+ context[dim_for_slice, top_u_sparsity_measurement, :] = attn_output
+ attn_output = context
+
+ if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+
+ attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
+ attn_output = attn_output.transpose(1, 2)
+
+ # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
+ # partitioned across GPUs when using tensor-parallelism.
+ attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+
+ attn_output = self.out_proj(attn_output)
+
+ return attn_output, attn_weights_reshaped, past_key_value
+
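To make the sampling sizes above concrete: both the number of sampled keys (`u_part`) and the number of retained "active" queries (`u`) grow only logarithmically with sequence length. A small numeric sketch using the same formulas (the sequence length is chosen purely for illustration):

```python
import numpy as np

sampling_factor = 5
L_Q = L_K = 96  # hypothetical sequence lengths

log_L_K = int(np.ceil(np.log1p(L_K)))  # 5
log_L_Q = int(np.ceil(np.log1p(L_Q)))  # 5

u_part = min(sampling_factor * L_Q * log_L_K, L_K)  # min(2400, 96) = 96 keys sampled
u = min(sampling_factor * log_L_Q, L_Q)             # min(25, 96)   = 25 "active" queries kept
print(u_part, u)  # 96 25
```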
+
+# source: https://github.com/zhouhaoyi/Informer2020/blob/main/models/encoder.py
+class InformerConvLayer(nn.Module):
+ def __init__(self, c_in):
+ super().__init__()
+ self.downConv = nn.Conv1d(
+ in_channels=c_in,
+ out_channels=c_in,
+ kernel_size=3,
+ padding=1,
+ padding_mode="circular",
+ )
+ self.norm = nn.BatchNorm1d(c_in)
+ self.activation = nn.ELU()
+ self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+ def forward(self, x):
+ x = self.downConv(x.permute(0, 2, 1))
+ x = self.norm(x)
+ x = self.activation(x)
+ x = self.maxPool(x)
+ x = x.transpose(1, 2)
+ return x
+
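As defined above, the distillation conv layer keeps the channel count but roughly halves the time dimension through the stride-2 max pool. A minimal sketch, assuming `InformerConvLayer` as written here (shapes are illustrative):

```python
import torch

layer = InformerConvLayer(c_in=16)  # c_in == d_model
x = torch.randn(2, 96, 16)          # (batch, seq_len, d_model)
print(layer(x).shape)               # torch.Size([2, 48, 16]) -> seq_len halved
```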
+
+class InformerEncoderLayer(TimeSeriesTransformerEncoderLayer):
+ def __init__(self, config: InformerConfig):
+ super().__init__(config)
+
+ del self.self_attn
+
+ if config.attention_type == "prob":
+ self.self_attn = InformerProbSparseAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ sampling_factor=config.sampling_factor,
+ )
+ else:
+ self.self_attn = InformerAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.encoder_attention_heads,
+ dropout=config.attention_dropout,
+ config=config,
+ )
+
+
+class InformerDecoderLayer(TimeSeriesTransformerDecoderLayer):
+ def __init__(self, config: InformerConfig, layer_idx: Optional[int] = None):
+ super().__init__(config)
+
+ del self.self_attn
+
+ if config.attention_type == "prob":
+ self.self_attn = InformerProbSparseAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ sampling_factor=config.sampling_factor,
+ is_decoder=True,
+ layer_idx=layer_idx,
+ )
+ else:
+ self.self_attn = InformerAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+
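Both layer classes above pick their self-attention module from `config.attention_type`: `"prob"` selects `InformerProbSparseAttention`, anything else falls back to the standard `InformerAttention`. A minimal sketch of the relevant config switches (values are illustrative):

```python
from transformers import InformerConfig

config = InformerConfig(attention_type="prob", sampling_factor=5, prediction_length=24)
print(config.attention_type, config.sampling_factor)  # prob 5
```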
+
+class InformerEncoder(TimeSeriesTransformerEncoder):
+ def __init__(self, config: InformerConfig):
+ super().__init__(config)
+
+ self.dropout = config.dropout
+ self.layerdrop = config.encoder_layerdrop
+ self.gradient_checkpointing = False
+ if config.prediction_length is None:
+ raise ValueError("The `prediction_length` config needs to be specified.")
+
+ self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model)
+ self.embed_positions = InformerSinusoidalPositionalEmbedding(
+ config.context_length + config.prediction_length, config.d_model
+ )
+ self.layers = nn.ModuleList([InformerEncoderLayer(config) for _ in range(config.encoder_layers)])
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+ if config.distil:
+ self.conv_layers = nn.ModuleList(
+ [InformerConvLayer(config.d_model) for _ in range(config.encoder_layers - 1)]
+ )
+ self.conv_layers.append(None)
+ else:
+ self.conv_layers = [None] * config.encoder_layers
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ def forward(
+ self,
+ attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ Args:
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
+
+ - 1 for tokens that are **not masked**,
+ - 0 for tokens that are **masked**.
+
+ [What are attention masks?](../glossary#attention-mask)
+ head_mask (`torch.Tensor` of shape `(encoder_layers, encoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation.
+ This is useful if you want more control over how to convert `input_ids` indices into associated vectors
+ than the model's internal embedding lookup matrix.
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ output_hidden_states (`bool`, *optional*):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more detail.
+ return_dict (`bool`, *optional*):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ hidden_states = self.value_embedding(inputs_embeds)
+ embed_pos = self.embed_positions(inputs_embeds.size())
+
+ hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+
+ # expand attention_mask
+ if attention_mask is not None:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ encoder_states = () if output_hidden_states else None
+ all_attentions = () if output_attentions else None
+
+ # check if head_mask has a correct number of layers specified if desired
+ if head_mask is not None:
+ if head_mask.size()[0] != (len(self.layers)):
+ raise ValueError(
+ f"The head_mask should be specified for {len(self.layers)} layers, but it is for"
+ f" {head_mask.size()[0]}."
+ )
+
+ for idx, (encoder_layer, conv_layer) in enumerate(zip(self.layers, self.conv_layers)):
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ to_drop = False
+ if self.training:
+ dropout_probability = torch.rand([])
+ if dropout_probability < self.layerdrop: # skip the layer
+ to_drop = True
+
+ if to_drop:
+ layer_outputs = (None, None)
+ else:
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ encoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ (head_mask[idx] if head_mask is not None else None),
+ output_attentions,
+ )
+ if conv_layer is not None:
+ output = self._gradient_checkpointing_func(conv_layer, layer_outputs[0])
+ layer_outputs = (output,) + layer_outputs[1:]
+ else:
+ layer_outputs = encoder_layer(
+ hidden_states,
+ attention_mask,
+ layer_head_mask=(head_mask[idx] if head_mask is not None else None),
+ output_attentions=output_attentions,
+ )
+ if conv_layer is not None:
+ output = conv_layer(layer_outputs[0])
+ layer_outputs = (output,) + layer_outputs[1:]
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_attentions = all_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ encoder_states = encoder_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
+ )
+
+
+class InformerDecoder(TimeSeriesTransformerDecoder):
+ def __init__(self, config: InformerConfig):
+ super().__init__(config)
+ self.dropout = config.dropout
+ self.layerdrop = config.decoder_layerdrop
+ if config.prediction_length is None:
+ raise ValueError("The `prediction_length` config needs to be specified.")
+
+ self.value_embedding = InformerValueEmbedding(feature_size=config.feature_size, d_model=config.d_model)
+ self.embed_positions = InformerSinusoidalPositionalEmbedding(
+ config.context_length + config.prediction_length, config.d_model
+ )
+ self.layers = nn.ModuleList([InformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
+ self.layernorm_embedding = nn.LayerNorm(config.d_model)
+
+ self.gradient_checkpointing = False
+ # Initialize weights and apply final processing
+ self.post_init()
+
+
+class InformerModel(TimeSeriesTransformerModel, nn.Module):
+ def __init__(self, config: InformerConfig):
+ nn.Module().__init__(config)
+
+ if config.scaling == "mean" or config.scaling is True:
+ self.scaler = InformerMeanScaler(config)
+ elif config.scaling == "std":
+ self.scaler = InformerStdScaler(config)
+ else:
+ self.scaler = InformerNOPScaler(config)
+
+ if config.num_static_categorical_features > 0:
+ self.embedder = InformerFeatureEmbedder(
+ cardinalities=config.cardinality,
+ embedding_dims=config.embedding_dimension,
+ )
+
+ # transformer encoder-decoder and mask initializer
+ self.encoder = InformerEncoder(config)
+ self.decoder = InformerDecoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
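The constructor above selects the input scaler from `config.scaling`: `"mean"` (or `True`) gives the mean scaler, `"std"` the std scaler, and anything else the no-op scaler. A minimal sketch with illustrative config values:

```python
from transformers import InformerConfig, InformerModel

config = InformerConfig(prediction_length=24, context_length=48, scaling="std")
model = InformerModel(config)
print(type(model.scaler).__name__)  # InformerStdScaler
```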
+ def forward(self, **super_kwargs):
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
+ Past values of the time series, that serve as context in order to predict the future. The sequence size of
+ this tensor must be larger than the `context_length` of the model, since the model will use the larger size
+ to construct lag features, i.e. additional values from the past which are added in order to serve as "extra
+ context".
+
+ The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no
+ `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest
+ look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of
+ the past.
+
+ The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
+ `static_categorical_features`, `static_real_features`, `past_time_features` and lags).
+
+ Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
+
+ For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+ variates in the time series per time step.
+ past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
+ Required time features, which the model internally will add to `past_values`. These could be things like
+ "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
+ could also be so-called "age" features, which basically help the model know "at which point in life" a
+ time-series is. Age features have small values for distant past time steps and increase monotonically the
+ more we approach the current time step. Holiday features are also a good example of time features.
+
+ These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+ the position encodings are learned from scratch internally as parameters of the model, the Time Series
+ Transformer requires to provide additional time features. The Time Series Transformer only learns
+ additional embeddings for `static_categorical_features`.
+
+ Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+ must be known at prediction time.
+
+ The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+ past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in
+ `[0, 1]`:
+
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+ static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*):
+ Optional static categorical features for which the model will learn an embedding, which it will add to the
+ values of the time series.
+
+ Static categorical features are features which have the same value for all time steps (static over time).
+
+ A typical example of a static categorical feature is a time series ID.
+ static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
+ Optional static real features which the model will add to the values of the time series.
+
+ Static real features are features which have the same value for all time steps (static over time).
+
+ A typical example of a static real feature is promotion information.
+ future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*):
+ Future values of the time series, that serve as labels for the model. The `future_values` is what the
+ Transformer needs during training to learn to output, given the `past_values`.
+
+ The sequence length here is equal to `prediction_length`.
+
+ See the demo notebook and code snippets for details.
+
+ Optionally, during training any missing values need to be replaced with zeros and indicated via the
+ `future_observed_mask`.
+
+ For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+ variates in the time series per time step.
+ future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
+ Required time features for the prediction window, which the model internally will add to `future_values`.
+ These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as
+ Fourier features). These could also be so-called "age" features, which basically help the model know "at
+ which point in life" a time-series is. Age features have small values for distant past time steps and
+ increase monotonically the more we approach the current time step. Holiday features are also a good example
+ of time features.
+
+ These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+ the position encodings are learned from scratch internally as parameters of the model, the Time Series
+        Transformer requires you to provide additional time features. The Time Series Transformer only learns
+ additional embeddings for `static_categorical_features`.
+
+ Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+        must be known at prediction time.
+
+        The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+    encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+
+ Examples:
+
+ ```python
+ >>> from huggingface_hub import hf_hub_download
+ >>> import torch
+ >>> from transformers import InformerModel
+
+ >>> file = hf_hub_download(
+ ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
+ ... )
+ >>> batch = torch.load(file)
+
+ >>> model = InformerModel.from_pretrained("huggingface/informer-tourism-monthly")
+
+ >>> # during training, one provides both past and future values
+ >>> # as well as possible additional features
+ >>> outputs = model(
+ ... past_values=batch["past_values"],
+ ... past_time_features=batch["past_time_features"],
+ ... past_observed_mask=batch["past_observed_mask"],
+ ... static_categorical_features=batch["static_categorical_features"],
+ ... static_real_features=batch["static_real_features"],
+ ... future_values=batch["future_values"],
+ ... future_time_features=batch["future_time_features"],
+ ... )
+
+ >>> last_hidden_state = outputs.last_hidden_state
+ ```"""
+        return super().forward(**super_kwargs)
+
+
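As a quick check on the `sequence_length` arithmetic described in the `past_values` docstring above, here is a minimal sketch using the public `InformerConfig`; the concrete values are illustrative assumptions, not anything taken from this patch.

```python
from transformers import InformerConfig

# hypothetical configuration, chosen only to make the arithmetic concrete
config = InformerConfig(
    prediction_length=24,
    context_length=24,
    lags_sequence=[1, 2, 3, 7],
    num_time_features=2,
)

# past_values must cover the context window plus the largest lag so that
# lag features can be built for every step of the context
expected_past_length = config.context_length + max(config.lags_sequence)
print(expected_past_length)  # 24 + 7 = 31
```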
+class InformerForPrediction(TimeSeriesTransformerForPrediction, nn.Module):
+ def __init__(self, config: InformerConfig):
+ nn.Module().__init__(config)
+
+ self.model = InformerModel(config)
+ if config.distribution_output == "student_t":
+ self.distribution_output = StudentTOutput(dim=config.input_size)
+ elif config.distribution_output == "normal":
+ self.distribution_output = NormalOutput(dim=config.input_size)
+ elif config.distribution_output == "negative_binomial":
+ self.distribution_output = NegativeBinomialOutput(dim=config.input_size)
+ else:
+ raise ValueError(f"Unknown distribution output {config.distribution_output}")
+
+ self.parameter_projection = self.distribution_output.get_parameter_projection(self.model.config.d_model)
+ self.target_shape = self.distribution_output.event_shape
+
+ if config.loss == "nll":
+ self.loss = nll
+ else:
+ raise ValueError(f"Unknown loss function {config.loss}")
+
+ # Initialize weights of distribution_output and apply final processing
+ self.post_init()
+
+ @auto_docstring
+ def forward(self, **super_kwargs):
+ r"""
+ past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`):
+ Past values of the time series, that serve as context in order to predict the future. The sequence size of
+ this tensor must be larger than the `context_length` of the model, since the model will use the larger size
+ to construct lag features, i.e. additional values from the past which are added in order to serve as "extra
+ context".
+
+ The `sequence_length` here is equal to `config.context_length` + `max(config.lags_sequence)`, which if no
+ `lags_sequence` is configured, is equal to `config.context_length` + 7 (as by default, the largest
+ look-back index in `config.lags_sequence` is 7). The property `_past_length` returns the actual length of
+ the past.
+
+ The `past_values` is what the Transformer encoder gets as input (with optional additional features, such as
+ `static_categorical_features`, `static_real_features`, `past_time_features` and lags).
+
+ Optionally, missing values need to be replaced with zeros and indicated via the `past_observed_mask`.
+
+ For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+ variates in the time series per time step.
+ past_time_features (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_features)`):
+ Required time features, which the model internally will add to `past_values`. These could be things like
+ "month of year", "day of the month", etc. encoded as vectors (for instance as Fourier features). These
+ could also be so-called "age" features, which basically help the model know "at which point in life" a
+ time-series is. Age features have small values for distant past time steps and increase monotonically the
+ more we approach the current time step. Holiday features are also a good example of time features.
+
+ These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+ the position encodings are learned from scratch internally as parameters of the model, the Time Series
+        Transformer requires you to provide additional time features. The Time Series Transformer only learns
+ additional embeddings for `static_categorical_features`.
+
+ Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+        must be known at prediction time.
+
+        The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+ past_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+ Boolean mask to indicate which `past_values` were observed and which were missing. Mask values selected in
+ `[0, 1]`:
+
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+ static_categorical_features (`torch.LongTensor` of shape `(batch_size, number of static categorical features)`, *optional*):
+ Optional static categorical features for which the model will learn an embedding, which it will add to the
+ values of the time series.
+
+ Static categorical features are features which have the same value for all time steps (static over time).
+
+ A typical example of a static categorical feature is a time series ID.
+ static_real_features (`torch.FloatTensor` of shape `(batch_size, number of static real features)`, *optional*):
+ Optional static real features which the model will add to the values of the time series.
+
+ Static real features are features which have the same value for all time steps (static over time).
+
+ A typical example of a static real feature is promotion information.
+ future_values (`torch.FloatTensor` of shape `(batch_size, prediction_length)` or `(batch_size, prediction_length, input_size)`, *optional*):
+ Future values of the time series, that serve as labels for the model. The `future_values` is what the
+ Transformer needs during training to learn to output, given the `past_values`.
+
+ The sequence length here is equal to `prediction_length`.
+
+ See the demo notebook and code snippets for details.
+
+ Optionally, during training any missing values need to be replaced with zeros and indicated via the
+ `future_observed_mask`.
+
+ For multivariate time series, the `input_size` > 1 dimension is required and corresponds to the number of
+ variates in the time series per time step.
+ future_time_features (`torch.FloatTensor` of shape `(batch_size, prediction_length, num_features)`):
+ Required time features for the prediction window, which the model internally will add to `future_values`.
+ These could be things like "month of year", "day of the month", etc. encoded as vectors (for instance as
+ Fourier features). These could also be so-called "age" features, which basically help the model know "at
+ which point in life" a time-series is. Age features have small values for distant past time steps and
+ increase monotonically the more we approach the current time step. Holiday features are also a good example
+ of time features.
+
+ These features serve as the "positional encodings" of the inputs. So contrary to a model like BERT, where
+ the position encodings are learned from scratch internally as parameters of the model, the Time Series
+        Transformer requires you to provide additional time features. The Time Series Transformer only learns
+ additional embeddings for `static_categorical_features`.
+
+ Additional dynamic real covariates can be concatenated to this tensor, with the caveat that these features
+        must be known at prediction time.
+
+        The `num_features` here is equal to `config.num_time_features` + `config.num_dynamic_real_features`.
+ future_observed_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`, *optional*):
+ Boolean mask to indicate which `future_values` were observed and which were missing. Mask values selected
+ in `[0, 1]`:
+
+ - 1 for values that are **observed**,
+ - 0 for values that are **missing** (i.e. NaNs that were replaced by zeros).
+
+ This mask is used to filter out missing values for the final loss calculation.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+    encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
+ Tuple consists of `last_hidden_state`, `hidden_states` (*optional*) and `attentions` (*optional*)
+ `last_hidden_state` of shape `(batch_size, sequence_length, hidden_size)` (*optional*) is a sequence of
+ hidden-states at the output of the last layer of the encoder. Used in the cross-attention of the decoder.
+
+ Examples:
+
+ ```python
+ >>> from huggingface_hub import hf_hub_download
+ >>> import torch
+ >>> from transformers import InformerForPrediction
+
+ >>> file = hf_hub_download(
+ ... repo_id="hf-internal-testing/tourism-monthly-batch", filename="train-batch.pt", repo_type="dataset"
+ ... )
+ >>> batch = torch.load(file)
+
+ >>> model = InformerForPrediction.from_pretrained(
+ ... "huggingface/informer-tourism-monthly"
+ ... )
+
+ >>> # during training, one provides both past and future values
+ >>> # as well as possible additional features
+ >>> outputs = model(
+ ... past_values=batch["past_values"],
+ ... past_time_features=batch["past_time_features"],
+ ... past_observed_mask=batch["past_observed_mask"],
+ ... static_categorical_features=batch["static_categorical_features"],
+ ... static_real_features=batch["static_real_features"],
+ ... future_values=batch["future_values"],
+ ... future_time_features=batch["future_time_features"],
+ ... )
+
+ >>> loss = outputs.loss
+ >>> loss.backward()
+
+ >>> # during inference, one only provides past values
+ >>> # as well as possible additional features
+ >>> # the model autoregressively generates future values
+ >>> outputs = model.generate(
+ ... past_values=batch["past_values"],
+ ... past_time_features=batch["past_time_features"],
+ ... past_observed_mask=batch["past_observed_mask"],
+ ... static_categorical_features=batch["static_categorical_features"],
+ ... static_real_features=batch["static_real_features"],
+ ... future_time_features=batch["future_time_features"],
+ ... )
+
+ >>> mean_prediction = outputs.sequences.mean(dim=1)
+ ```"""
+        return super().forward(**super_kwargs)
+
+
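The `generate` example in the docstring above reduces `outputs.sequences` to a point forecast with `.mean(dim=1)`. The sample dimension also supports prediction intervals; below is a small self-contained sketch with a stand-in tensor whose shape `(batch_size, num_parallel_samples, prediction_length)` mirrors what `generate` returns for this model.

```python
import torch

# stand-in for outputs.sequences from the example above
samples = torch.randn(64, 100, 24)  # (batch_size, num_parallel_samples, prediction_length)

mean_prediction = samples.mean(dim=1)  # point forecast per series
# 80% prediction interval taken over the sample dimension
p10, p90 = torch.quantile(samples, torch.tensor([0.1, 0.9]), dim=1)
print(mean_prediction.shape, p10.shape, p90.shape)  # all torch.Size([64, 24])
```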
+__all__ = ["InformerForPrediction", "InformerModel", "InformerPreTrainedModel"]
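The `past_time_features`/`future_time_features` descriptions above leave the feature construction to the user. Below is a hedged sketch of one common way to build such features for a monthly series; the scaling and the log "age" feature are illustrative choices, not requirements of the model.

```python
import numpy as np
import pandas as pd

sequence_length = 31  # context_length + max(lags_sequence), see the docstrings above
periods = pd.period_range("1979-01", periods=sequence_length, freq="M")

month_of_year = (periods.month - 1) / 11.0 - 0.5  # calendar feature in [-0.5, 0.5]
age = np.log1p(np.arange(sequence_length))        # grows monotonically toward "now"

# (batch_size, sequence_length, num_time_features), as expected by the model
past_time_features = np.stack([month_of_year, age], axis=-1)[None, ...]
print(past_time_features.shape)  # (1, 31, 2)
```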
diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py
index 757a393a0bc0..55ecad415233 100755
--- a/src/transformers/models/m2m_100/modeling_m2m_100.py
+++ b/src/transformers/models/m2m_100/modeling_m2m_100.py
@@ -15,7 +15,7 @@
"""PyTorch M2M100 model."""
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -31,31 +31,23 @@
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- auto_docstring,
- is_torch_flex_attn_available,
- is_torchdynamo_compiling,
- logging,
-)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from .configuration_m2m_100 import M2M100Config
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
-
-
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -184,6 +176,37 @@ def create_position_ids_from_inputs_embeds(self, inputs_embeds, past_key_values_
return position_ids.unsqueeze(0).expand(input_shape).contiguous() + past_key_values_length
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
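One behavioural detail of `eager_attention_forward` above: the per-layer head mask is applied by broadcasting `head_mask.view(1, -1, 1, 1)` over the `(batch, num_heads, tgt_len, src_len)` attention weights, instead of the `(batch * num_heads, ...)` reshapes used by the code removed below. A toy sketch of that broadcast, independent of any model:

```python
import torch

bsz, num_heads, tgt_len, src_len = 2, 4, 5, 5
attn_weights = torch.rand(bsz, num_heads, tgt_len, src_len)

# zero out heads 1 and 3 for this layer
layer_head_mask = torch.tensor([1.0, 0.0, 1.0, 0.0])
masked = attn_weights * layer_head_mask.view(1, -1, 1, 1)

print(masked[:, 1].abs().sum().item(), masked[:, 3].abs().sum().item())  # 0.0 0.0
```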
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->M2M100
class M2M100Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -236,152 +259,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, we have mixed things: encoder, decoder, and encoder-decoder attn
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
-
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
-
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->M2M100
-class M2M100FlashAttention2(M2M100Attention):
- """
- M2M100 flash attention module. This module inherits from `M2M100Attention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # M2M100FlashAttention2 attention does not support output_attentions
- if output_attentions:
- raise ValueError(
- "M2M100SdpaAttention2 attention does not support `output_attentions`. "
- "Use the argument `attn_implementation='eager'` when loading the model."
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -402,8 +298,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -415,157 +311,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->M2M100
-class M2M100SdpaAttention(M2M100Attention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "M2M100Model is using M2M100SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- cache_position=cache_position,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
-
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- causal_mask = None
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100, MBART->M2M100
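With the `M2M100FlashAttention2` and `M2M100SdpaAttention` subclasses removed, the kernel is now chosen at runtime from `config._attn_implementation` through the shared registry. A minimal sketch of that lookup (internal API, shown only to illustrate the dispatch; "eager" deliberately bypasses the registry and falls back to the module-level `eager_attention_forward`):

```python
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

# same lookup pattern as in the refactored forward above
for impl in ("sdpa", "flash_attention_2", "flex_attention"):
    if impl in ALL_ATTENTION_FUNCTIONS:
        print(impl, "->", ALL_ATTENTION_FUNCTIONS[impl])
```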
@@ -574,7 +340,7 @@ def __init__(self, config: M2M100Config):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = M2M100Attention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -637,20 +403,13 @@ def forward(
return outputs
-M2M100_ATTENTION_CLASSES = {
- "eager": M2M100Attention,
- "flash_attention_2": M2M100FlashAttention2,
- "sdpa": M2M100SdpaAttention,
-}
-
-
# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100
class M2M100DecoderLayer(nn.Module):
def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = M2M100Attention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -664,7 +423,7 @@ def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = M2M100_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = M2M100Attention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -732,7 +491,6 @@ def forward(
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -772,6 +530,8 @@ class M2M100PreTrainedModel(PreTrainedModel):
_no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"]
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
# Doesn't support `compile` (dynamic control flow). Can be fixed but low usage model
_supports_static_cache = False
@@ -790,24 +550,56 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
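For reference, the eager/sdpa branch of `_update_full_mask` above turns a `[bsz, seq_len]` padding mask into the additive `[bsz, 1, tgt_seq_len, src_seq_len]` form. A quick sketch using the same private helper this file already imports (internal API, illustration only):

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# two sequences; the second one has its last two positions padded
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])

expanded = _prepare_4d_attention_mask(attention_mask, dtype=torch.float32)
print(expanded.shape)      # torch.Size([2, 1, 4, 4])
print(expanded[1, 0, 0])   # 0. for kept keys, the dtype minimum for padded keys
```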
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -815,7 +607,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -849,7 +641,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -915,6 +706,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
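`_update_cross_attn_mask` above performs the same expansion for the encoder padding mask, except that the target length comes from the decoder. A short sketch with mismatched lengths, again using the private helper for illustration only:

```python
import torch
from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask

# encoder padding mask with src_len=6, expanded for a decoder with tgt_len=3
encoder_attention_mask = torch.tensor([[1, 1, 1, 1, 0, 0]])
cross_mask = _prepare_4d_attention_mask(encoder_attention_mask, torch.float32, tgt_len=3)
print(cross_mask.shape)  # torch.Size([1, 1, 3, 6])
```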
class M2M100Encoder(M2M100PreTrainedModel):
"""
@@ -951,8 +778,6 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
)
self.layers = nn.ModuleList([M2M100EncoderLayer(config) for _ in range(config.encoder_layers)])
self.layer_norm = nn.LayerNorm(config.d_model)
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
self.gradient_checkpointing = False
# Initialize weights and apply final processing
@@ -1031,18 +856,10 @@ def forward(
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- if self._use_flash_attention_2:
- attention_mask = attention_mask if 0 in attention_mask else None
- elif self._use_sdpa and head_mask is None and not output_attentions:
- # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1133,8 +950,6 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] =
self.padding_idx,
)
self.layers = nn.ModuleList([M2M100DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
self.layer_norm = nn.LayerNorm(config.d_model)
self.gradient_checkpointing = False
@@ -1232,23 +1047,27 @@ def forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input_shape = input_ids.size()
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
# initialize `past_key_values`
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
@@ -1277,35 +1096,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
- causal_mask = self._update_causal_mask(
+
+ attention_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
- elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=seq_length,
- )
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=seq_length,
- )
-
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
positions = positions.to(inputs_embeds.device)
@@ -1313,13 +1116,6 @@ def forward(
hidden_states = inputs_embeds + positions
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
@@ -1351,7 +1147,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- causal_mask,
+ attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -1364,7 +1160,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=causal_mask,
+ attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index a9a3fd353ecf..016cb865f83d 100755
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -16,7 +16,7 @@
import copy
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -30,7 +30,9 @@
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -38,7 +40,8 @@
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -108,6 +111,37 @@ def forward(
return super().forward(position_ids)
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian
class MarianAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -160,17 +194,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, we have mixed things: encoder, decoder, and encoder-decoder attn
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -191,8 +233,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -204,66 +246,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN
@@ -272,7 +275,7 @@ def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = MarianAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -338,16 +341,13 @@ def forward(
return outputs
-MARIAN_ATTENTION_CLASSES = {"eager": MarianAttention}
-
-
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN
class MarianDecoderLayer(nn.Module):
def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = MarianAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -361,7 +361,7 @@ def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = MARIAN_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = MarianAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -435,6 +435,7 @@ def forward(
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@@ -465,6 +466,10 @@ class MarianPreTrainedModel(PreTrainedModel):
config_class = MarianConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
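Since `_supports_sdpa` and `_supports_flash_attn_2` are newly advertised for Marian here, a simple parity check between the eager path and SDPA is one way to sanity-check the refactor; the checkpoint and tolerance below are illustrative choices.

```python
import torch
from transformers import AutoTokenizer, MarianMTModel

name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = AutoTokenizer.from_pretrained(name)
inputs = tokenizer("Attention refactors should not change the math.", return_tensors="pt")

eager = MarianMTModel.from_pretrained(name, attn_implementation="eager").eval()
sdpa = MarianMTModel.from_pretrained(name, attn_implementation="sdpa").eval()

with torch.no_grad():
    decoder_input_ids = torch.tensor([[eager.config.decoder_start_token_id]])
    logits_eager = eager(**inputs, decoder_input_ids=decoder_input_ids).logits
    logits_sdpa = sdpa(**inputs, decoder_input_ids=decoder_input_ids).logits

print(torch.allclose(logits_eager, logits_sdpa, atol=1e-4))
```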
@@ -484,24 +489,67 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP
module.weight.data.fill_(1.0)
module.bias.data.zero_()
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ @property
+ def dummy_inputs(self):
+ pad_token = self.config.pad_token_id
+ input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
+ dummy_inputs = {
+ "attention_mask": input_ids.ne(pad_token),
+ "input_ids": input_ids,
+ "decoder_input_ids": input_ids,
+ }
+ return dummy_inputs
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
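For the eager and SDPA paths, `_update_full_mask` relies on `_prepare_4d_attention_mask(_for_sdpa)` to turn the 2D padding mask into a 4D additive bias. The sketch below is a rough standalone illustration of that expansion; `expand_padding_mask` is a made-up name, not the library helper.

```python
import torch
from typing import Optional


def expand_padding_mask(mask_2d: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    """Turn a [bsz, src_len] keep/pad mask into an additive [bsz, 1, tgt_len, src_len] bias."""
    bsz, src_len = mask_2d.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len)
    bias = torch.zeros(bsz, 1, tgt_len, src_len, dtype=dtype)
    return bias.masked_fill(expanded == 0, torch.finfo(dtype).min)


mask = torch.tensor([[1, 1, 1, 0]])              # last key position is padding
bias = expand_padding_mask(mask, torch.float32)
print(bias.shape)                                # torch.Size([1, 1, 4, 4])
print(bias[0, 0, 0])                             # zeros for kept keys, finfo.min for the padded one
```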
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -509,7 +557,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -543,7 +591,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -609,16 +656,41 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
- @property
- def dummy_inputs(self):
- pad_token = self.config.pad_token_id
- input_ids = torch.tensor([[0, 6, 10, 4, 2], [0, 8, 12, 2, pad_token]], device=self.device)
- dummy_inputs = {
- "attention_mask": input_ids.ne(pad_token),
- "input_ids": input_ids,
- "decoder_input_ids": input_ids,
- }
- return dummy_inputs
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
class MarianEncoder(MarianPreTrainedModel):
@@ -734,10 +806,10 @@ def forward(
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -929,12 +1001,18 @@ def forward(
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
# Important to apply outside of the above `if`, in case user passes `embeds`
inputs_embeds = inputs_embeds * self.embed_scale
@@ -967,22 +1045,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=seq_length,
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
position_ids = self.embed_positions(
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 3fbb3e8b5be4..bdf352a1f646 100755
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -16,7 +16,7 @@
import copy
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
@@ -31,7 +31,9 @@
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -41,7 +43,8 @@
Seq2SeqQuestionAnsweringModelOutput,
Seq2SeqSequenceClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -52,13 +55,7 @@
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
-
-
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -124,6 +121,37 @@ def forward(self, input_ids: torch.Tensor):
return super().forward(input_ids) * self.embed_scale
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
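`eager_attention_forward` keeps `head_mask` support by scaling the attention weights per head before the value matmul, and returns the output already transposed to `(bsz, seq_len, num_heads, head_dim)`. A small standalone check of both behaviors (the helper is repeated here so the snippet runs on its own):

```python
import torch
import torch.nn as nn


def eager_attention_forward(module, query, key, value, attention_mask=None, scaling=None,
                            dropout=0.0, head_mask=None, **kwargs):
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


module = nn.Module()                  # only `.training` is read from it
q = k = v = torch.randn(1, 2, 5, 4)   # (bsz, num_heads, seq_len, head_dim)
out, _ = eager_attention_forward(module, q, k, v, head_mask=torch.tensor([1.0, 0.0]))
print(out.shape)                      # torch.Size([1, 5, 2, 4]): (bsz, seq_len, num_heads, head_dim)
print(out[:, :, 1].abs().max())       # tensor(0.): the masked head contributes nothing
```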
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MBart
class MBartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -176,152 +204,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attention are mixed in the same module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MBart
-class MBartFlashAttention2(MBartAttention):
- """
- MBart flash attention module. This module inherits from `MBartAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # MBartFlashAttention2 attention does not support output_attentions
- if output_attentions:
- raise ValueError(
- "MBartSdpaAttention2 attention does not support `output_attentions`. "
- "Use the argument `attn_implementation='eager'` when loading the model."
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -342,8 +243,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -355,164 +256,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache
- # to be able to avoid many of these transpose/reshape/view.
- key_states = key_states.transpose(1, 2)
- value_states = value_states.transpose(1, 2)
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- attn_output = _flash_attention_forward(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- return attn_output, None, past_key_value
-
-
-# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart
-class MBartSdpaAttention(MBartAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "MBartModel is using MBartSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- cache_position=cache_position,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- if isinstance(past_key_value, EncoderDecoderCache):
- is_updated = past_key_value.is_updated.get(self.layer_idx)
- if is_cross_attention:
- # after the first generated id, we can subsequently re-use all key/value_states from cache
- curr_past_key_value = past_key_value.cross_attention_cache
- else:
- curr_past_key_value = past_key_value.self_attention_cache
- else:
- curr_past_key_value = past_key_value
-
- current_states = key_value_states if is_cross_attention else hidden_states
- if is_cross_attention and past_key_value is not None and is_updated:
- # reuse k,v, cross_attentions
- key_states = curr_past_key_value.key_cache[self.layer_idx]
- value_states = curr_past_key_value.value_cache[self.layer_idx]
- else:
- key_states = self.k_proj(current_states)
- value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- if past_key_value is not None:
- # save all key/value_states to cache to be re-used for fast auto-regressive generation
- cache_position = cache_position if not is_cross_attention else None
- key_states, value_states = curr_past_key_value.update(
- key_states, value_states, self.layer_idx, {"cache_position": cache_position}
- )
- # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
- if is_cross_attention:
- past_key_value.is_updated[self.layer_idx] = True
-
- causal_mask = None
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-MBART_ATTENTION_CLASSES = {
- "eager": MBartAttention,
- "sdpa": MBartSdpaAttention,
- "flash_attention_2": MBartFlashAttention2,
-}
+ return attn_output, attn_weights, past_key_value
class MBartEncoderLayer(nn.Module):
@@ -520,7 +284,7 @@ def __init__(self, config: MBartConfig):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = MBartAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -588,7 +352,7 @@ def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = MBartAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -602,7 +366,7 @@ def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = MBART_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = MBartAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -670,7 +434,6 @@ def forward(
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -735,6 +498,8 @@ class MBartPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
@@ -762,24 +527,56 @@ def dummy_inputs(self):
}
return dummy_inputs
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
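In the `flash_attention_2` branch of `_update_full_mask`, the 2D mask is kept only if it actually contains a 0 (i.e. real padding); otherwise `None` is passed so the kernel can skip the unpad/pad round trip. The tensor membership test used for that check behaves as follows:

```python
import torch

mask_no_padding = torch.ones(2, 5, dtype=torch.long)
mask_with_padding = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])

print(0 in mask_no_padding)    # False -> the mask is replaced by None
print(0 in mask_with_padding)  # True  -> the 2D mask is kept for unpadding
```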
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+                        device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -787,7 +584,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -821,7 +618,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -887,6 +683,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
class MBartEncoder(MBartPreTrainedModel):
"""
@@ -1007,18 +839,10 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- if self.config._attn_implementation == "flash_attention_2":
- attention_mask = attention_mask if 0 in attention_mask else None
- elif self.config._attn_implementation == "sdpa" and head_mask is None and not output_attentions:
- # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -1214,12 +1038,18 @@ def forward(
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
if self.gradient_checkpointing and self.training:
if use_cache:
@@ -1256,32 +1086,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self.config._attn_implementation == "flash_attention_2":
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
- elif self.config._attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=seq_length,
- )
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
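The decoder hunks above for Marian and MBart share the `(input_ids is None) ^ (inputs_embeds is not None)` guard. As a quick sanity check of that XOR, the loop below (with `object()` as a stand-in for a real tensor) shows it fires exactly when both inputs are passed or when neither is:

```python
for has_ids, has_embeds in [(True, False), (False, True), (True, True), (False, False)]:
    input_ids = object() if has_ids else None
    inputs_embeds = object() if has_embeds else None
    raises = (input_ids is None) ^ (inputs_embeds is not None)
    print(f"input_ids={has_ids!s:5} inputs_embeds={has_embeds!s:5} -> {'raises' if raises else 'ok'}")
```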
diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py
index 00a0afd62ad5..a0e21f586cfc 100644
--- a/src/transformers/models/musicgen/modeling_musicgen.py
+++ b/src/transformers/models/musicgen/modeling_musicgen.py
@@ -19,7 +19,7 @@
import math
import random
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -40,7 +40,9 @@
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -48,15 +50,17 @@
ModelOutput,
Seq2SeqLMOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel
from .configuration_musicgen import MusicgenConfig, MusicgenDecoderConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -146,7 +150,38 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0):
return self.weights.index_select(0, position_ids.view(-1)).detach()
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Musicgen
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Musicgen
class MusicgenAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -181,9 +216,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -192,6 +224,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attention are mixed in the same module
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -199,285 +234,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
-# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->Musicgen
-class MusicgenFlashAttention2(MusicgenAttention):
- """
- Musicgen flash attention module. This module inherits from `MusicgenAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class MusicgenSdpaAttention(MusicgenAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
-
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "MusicgenModel is using MusicgenSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- if (
- attention_mask is not None
- and (attention_mask.mean(dim=[1, 2, 3]) <= torch.finfo(attention_mask.dtype).min).any()
- ):
- logger.warning_once(
- '`torch.nn.functional.scaled_dot_product_attention` does not support having an empty attention mask. Falling back to the manual attention implementation. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- "Note that this probably happens because `guidance_scale>1` or because you used `get_unconditional_inputs`. See https://github.com/huggingface/transformers/issues/31189 for more information."
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -492,18 +258,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -515,46 +281,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-MUSICGEN_ATTENTION_CLASSES = {
- "eager": MusicgenAttention,
- "sdpa": MusicgenSdpaAttention,
- "flash_attention_2": MusicgenFlashAttention2,
-}
+ return attn_output, attn_weights, past_key_value
class MusicgenDecoderLayer(nn.Module):
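
The hunk above replaces the separate eager/SDPA/flash attention subclasses with a single `MusicgenAttention` that picks its backend at runtime through `ALL_ATTENTION_FUNCTIONS`. The sketch below only illustrates that dispatch pattern; the toy registry, function names, and the `scale` argument to `torch.nn.functional.scaled_dot_product_attention` are stand-ins of my own (assuming `torch>=2.1`), not the library's actual interface.

```python
import torch
import torch.nn.functional as F

def eager_attention(module, q, k, v, mask, scaling, dropout=0.0, **kwargs):
    # q, k, v: (bsz, num_heads, seq_len, head_dim)
    w = torch.matmul(q, k.transpose(2, 3)) * scaling
    if mask is not None:
        w = w + mask
    w = F.softmax(w, dim=-1)
    w = F.dropout(w, p=dropout, training=module.training)
    return torch.matmul(w, v).transpose(1, 2).contiguous(), w

def sdpa_attention(module, q, k, v, mask, scaling, dropout=0.0, **kwargs):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=dropout, scale=scaling)
    return out.transpose(1, 2).contiguous(), None

# toy stand-in for transformers' ALL_ATTENTION_FUNCTIONS registry
TOY_ATTENTION_FUNCTIONS = {"eager": eager_attention, "sdpa": sdpa_attention}

class ToyAttention(torch.nn.Module):
    def __init__(self, attn_implementation="eager"):
        super().__init__()
        self.attn_implementation = attn_implementation

    def forward(self, q, k, v, mask=None):
        scaling = q.size(-1) ** -0.5
        # same pattern as the refactored forward: default to eager, otherwise look up by name
        attention_interface = TOY_ATTENTION_FUNCTIONS[self.attn_implementation]
        return attention_interface(self, q, k, v, mask, scaling)

q = k = v = torch.randn(1, 2, 6, 8)
out, _ = ToyAttention("sdpa")(q, k, v)
print(out.shape)  # torch.Size([1, 6, 2, 8])
```

Because every backend receives the same `(module, query, key, value, mask, ...)` signature, the calling module stays identical no matter which implementation the config selects.
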
@@ -562,7 +309,7 @@ def __init__(self, config: MusicgenDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
- self.self_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = MusicgenAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
@@ -576,7 +323,7 @@ def __init__(self, config: MusicgenDecoderConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = MUSICGEN_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = MusicgenAttention(
self.embed_dim,
config.num_attention_heads,
dropout=config.attention_dropout,
@@ -590,6 +337,7 @@ def __init__(self, config: MusicgenDecoderConfig):
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
# copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
+ # TODO: change to new cache class
def forward(
self,
hidden_states: torch.Tensor,
@@ -688,6 +436,8 @@ class MusicgenPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MusicgenDecoderLayer", "MusicgenAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # compilation errors occur at the moment
+ _supports_flex_attn = False
def _init_weights(self, module):
std = self.config.initializer_factor
@@ -819,40 +569,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = sum([self.embed_tokens[codebook](input[:, codebook]) for codebook in range(num_codebooks)])
- if self.attn_implementation == "flash_attention_2":
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self.attn_implementation == "sdpa" and head_mask is None and not output_attentions:
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- input_shape,
- inputs_embeds,
- past_key_values_length,
- )
- else:
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
- )
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self.attn_implementation == "flash_attention_2":
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
- elif self.attn_implementation == "sdpa" and cross_attn_head_mask is None and not output_attentions:
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=input_shape[-1],
- )
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
positions = self.embed_positions(input, past_key_values_length)
@@ -951,6 +679,80 @@ def forward(
cross_attentions=all_cross_attentions,
)
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
@auto_docstring
class MusicgenModel(MusicgenPreTrainedModel):
@@ -1559,6 +1361,8 @@ class MusicgenForConditionalGeneration(PreTrainedModel, GenerationMixin):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # compilation errors occur at the moment
+ _supports_flex_attn = False
def __init__(
self,
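
Throughout the Musicgen refactor the decoder still carries the legacy tuple-based key/value cache (see the `# TODO: change to new cache class` note above): cached tensors have shape `(bsz, num_heads, past_len, head_dim)` and fresh projections are appended along the sequence dimension. A minimal sketch of that behavior, with made-up toy shapes:

```python
import torch

bsz, num_heads, head_dim = 2, 4, 8

def append_to_cache(past_key_value, key_states, value_states):
    # mirrors the `torch.cat([past_key_value[0], key_states], dim=2)` pattern above
    if past_key_value is not None:
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    return key_states, value_states, (key_states, value_states)

past = None
for step in range(3):  # pretend we decode three tokens one at a time
    new_k = torch.randn(bsz, num_heads, 1, head_dim)
    new_v = torch.randn(bsz, num_heads, 1, head_dim)
    k, v, past = append_to_cache(past, new_k, new_v)

print(k.shape)  # torch.Size([2, 4, 3, 8])
```
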
diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
index c5328d064769..3312ad33cdb1 100644
--- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
+++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py
@@ -19,7 +19,7 @@
import math
import random
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -34,18 +34,25 @@
LogitsProcessorList,
StoppingCriteriaList,
)
-from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import (
+ FlashAttentionKwargs,
+)
from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ..auto.configuration_auto import AutoConfig
from ..auto.modeling_auto import AutoModel, AutoModelForTextEncoding
from .configuration_musicgen_melody import MusicgenMelodyConfig, MusicgenMelodyDecoderConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
if TYPE_CHECKING:
from ...generation.streamers import BaseStreamer
@@ -159,7 +166,38 @@ def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0):
return self.weights.index_select(0, position_ids.view(-1)).detach()
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->MusicgenMelody
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->MusicgenMelody
class MusicgenMelodyAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -194,9 +232,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -205,6 +240,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, encoder, decoder, and encoder-decoder attention are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -212,269 +250,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->MusicgenMelody
-class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention):
- """
- MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-# Copied from transformers.models.hubert.modeling_hubert.HubertSdpaAttention with Hubert->MusicgenMelody
-class MusicgenMelodySdpaAttention(MusicgenMelodyAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "MusicgenMelodyModel is using MusicgenMelodySdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -489,18 +274,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -512,46 +297,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-MUSICGEN_MELODY_ATTENTION_CLASSES = {
- "eager": MusicgenMelodyAttention,
- "sdpa": MusicgenMelodySdpaAttention,
- "flash_attention_2": MusicgenMelodyFlashAttention2,
-}
+ return attn_output, attn_weights, past_key_value
class MusicgenMelodyDecoderLayer(nn.Module):
@@ -559,7 +325,7 @@ def __init__(self, config: MusicgenMelodyDecoderConfig):
super().__init__()
self.embed_dim = config.hidden_size
- self.self_attn = MUSICGEN_MELODY_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = MusicgenMelodyAttention(
embed_dim=self.embed_dim,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
@@ -644,6 +410,8 @@ class MusicgenMelodyPreTrainedModel(PreTrainedModel):
_no_split_modules = ["MusicgenMelodyDecoderLayer", "MusicgenMelodyAttention"]
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # compilation errors occur at the moment
+ _supports_flex_attn = False
def _init_weights(self, module):
std = self.config.initializer_factor
@@ -785,21 +553,12 @@ def forward(
input_shape = inputs_embeds.size()[:-1]
- if self.attn_implementation == "flash_attention_2":
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- elif self.attn_implementation == "sdpa" and not output_attentions:
- # output_attentions=True can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
- attention_mask,
- input_shape,
- inputs_embeds,
- past_key_values_length,
- )
- else:
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
- )
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
# embed positions
positions = self.embed_positions(inputs_embeds, past_key_values_length)
@@ -881,6 +640,57 @@ def forward(
attentions=all_attentions,
)
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ # Ignore copy
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # MusicgenMelody doesn't use cross attention, so this is a no-op
+ # that only exists to keep the copy checks consistent
+ pass
+
@auto_docstring
# Copied from transformers.models.musicgen.modeling_musicgen.MusicgenModel with MUSICGEN->MUSICGEN_MELODY,Musicgen->MusicgenMelody
@@ -1482,6 +1292,8 @@ class MusicgenMelodyForConditionalGeneration(PreTrainedModel, GenerationMixin):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # compilation errors occur at the moment
+ _supports_flex_attn = False
def __init__(
self,
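
Both Musicgen variants now import `make_flex_block_causal_mask` from `transformers.integrations.flex_attention` for the `flex_attention` branch of `_update_causal_mask`, even though `_supports_flex_attn` stays `False` for now because of the compilation errors noted above. The sketch below only illustrates the kind of block causal mask that branch builds, using PyTorch's flex attention API directly (assuming `torch>=2.5`; the transformers helper's exact signature is not reproduced here):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal_mask_mod(b, h, q_idx, kv_idx):
    # keep only positions at or before the query index
    return q_idx >= kv_idx

bsz, num_heads, seq_len, head_dim = 1, 2, 128, 16
# B and H broadcast when left as None; the mask depends only on positions here
block_mask = create_block_mask(causal_mask_mod, bsz, None, seq_len, seq_len, device="cpu")

q = torch.randn(bsz, num_heads, seq_len, head_dim)
k = torch.randn(bsz, num_heads, seq_len, head_dim)
v = torch.randn(bsz, num_heads, seq_len, head_dim)
out = flex_attention(q, k, v, block_mask=block_mask)
print(out.shape)  # torch.Size([1, 2, 128, 16])
```
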
diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
index 3729fd6e1815..6d7bd6c985d1 100644
--- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py
+++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py
@@ -15,7 +15,7 @@
"""PyTorch NLLB-MoE model."""
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -25,18 +25,29 @@
from ...generation import GenerationMixin
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
MoEModelOutput,
MoEModelOutputWithPastAndCrossAttentions,
Seq2SeqMoEModelOutput,
Seq2SeqMoEOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_nllb_moe import NllbMoeConfig
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
logger = logging.get_logger(__name__)
@@ -460,7 +471,38 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens
return hidden_states, (router_probs, top_1_expert_index)
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->NllbMoe,key_value_states->encoder_hidden_states
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->NllbMoe,key_value_states->encoder_hidden_states
class NllbMoeAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -495,9 +537,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -506,6 +545,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, encoder, decoder, and encoder-decoder attention are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -513,10 +555,16 @@ def forward(
# for the decoder
is_cross_attention = encoder_hidden_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = encoder_hidden_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
# get key, value proj
# `past_key_value[0].shape[2] == encoder_hidden_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -531,18 +579,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(encoder_hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(encoder_hidden_states), -1, bsz)
+ key_states = self.k_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(encoder_hidden_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -554,69 +602,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
class NllbMoeEncoderLayer(nn.Module):
@@ -628,6 +634,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
+ config=config,
)
self.attn_dropout = nn.Dropout(config.dropout)
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@@ -710,6 +717,7 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
is_decoder=True,
+ config=config,
)
self.dropout = config.dropout
self.activation_fn = ACT2FN[config.activation_function]
@@ -717,7 +725,11 @@ def __init__(self, config: NllbMoeConfig, is_sparse: bool = False):
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
self.cross_attention = NllbMoeAttention(
- self.embed_dim, config.decoder_attention_heads, config.attention_dropout, is_decoder=True
+ self.embed_dim,
+ config.decoder_attention_heads,
+ config.attention_dropout,
+ is_decoder=True,
+ config=config,
)
self.cross_attention_layer_norm = nn.LayerNorm(self.embed_dim)
if not self.is_sparse:
@@ -837,6 +849,12 @@ class NllbMoePreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["NllbMoeEncoderLayer", "NllbMoeDecoderLayer"]
+ # TODO: enable these once the corresponding tests pass
+ # Flash attention prepares masks differently from eager/sdpa and currently fails
+ # SDPA produces flakier logits and needs more investigation in the tests
+ _supports_flash_attn_2 = False
+ _supports_sdpa = False
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
@@ -975,10 +993,10 @@ def forward(
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_router_probs = () if output_router_logits else None
@@ -1042,6 +1060,29 @@ def forward(
router_probs=all_router_probs,
)
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class NllbMoeDecoder(NllbMoePreTrainedModel):
"""
@@ -1195,18 +1236,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
# embed positions
positions = self.embed_positions(input_ids, inputs_embeds, past_key_values_length)
@@ -1264,7 +1305,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.forward,
hidden_states,
- combined_attention_mask,
+ attention_mask,
encoder_hidden_states,
encoder_attention_mask,
layer_head_mask,
@@ -1276,7 +1317,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=combined_attention_mask,
+ attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=layer_head_mask,
@@ -1330,6 +1371,82 @@ def forward(
router_probs=all_router_probs,
)
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
@auto_docstring
class NllbMoeModel(NllbMoePreTrainedModel):
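
The `_update_full_mask` and `_update_cross_attn_mask` helpers added for NLLB-MoE keep relying on `_prepare_4d_attention_mask` (and its SDPA variant) to turn a `[bsz, seq_len]` padding mask into an additive `[bsz, 1, tgt_seq_len, src_seq_len]` mask. A rough, simplified re-implementation of that expansion for illustration only, not the library helper itself:

```python
import torch

def expand_padding_mask(mask_2d, dtype, tgt_len=None):
    # mask_2d: [bsz, src_len] with 1 = keep, 0 = padding
    bsz, src_len = mask_2d.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask_2d[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # padded positions get a large negative bias, kept positions get 0
    return inverted.masked_fill(inverted.bool(), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 1, 0]])  # one padded position
print(expand_padding_mask(mask, torch.float32).shape)  # torch.Size([1, 1, 4, 4])
```
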
diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
index c10a55b1b024..8f00e8900928 100644
--- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
+++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py
@@ -16,7 +16,7 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import torch
import torch.nn as nn
@@ -24,9 +24,11 @@
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import ModelOutput
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
+from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import auto_docstring, logging
-from ...utils.deprecation import deprecate_kwarg
from .configuration_patchtsmixer import PatchTSMixerConfig
@@ -235,7 +237,38 @@ def forward(self, inputs: torch.Tensor):
return out
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTSMixer
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTSMixer
class PatchTSMixerAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -248,7 +281,6 @@ def __init__(
bias: bool = True,
is_causal: bool = False,
config: Optional[PatchTSMixerConfig] = None,
- layer_idx: Optional[int] = None,
):
super().__init__()
self.embed_dim = embed_dim
@@ -265,23 +297,12 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
- self.layer_idx = layer_idx
- if layer_idx is None and self.is_decoder:
- logger.warning_once(
- f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
- "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- # Ignore copy
- @deprecate_kwarg("key_value_states", version="4.55")
- @deprecate_kwarg("past_key_value", version="4.55")
- @deprecate_kwarg("cache_position", version="4.55")
def forward(
self,
hidden_states: torch.Tensor,
@@ -290,79 +311,84 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, encoder, decoder, and encoder-decoder attention are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
- key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
+ # get query proj
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ # self_attention
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, None
+ return attn_output, attn_weights, past_key_value
class PatchMixerBlock(nn.Module):
@@ -395,6 +421,7 @@ def __init__(self, config: PatchTSMixerConfig):
embed_dim=config.d_model,
num_heads=config.self_attn_heads,
dropout=config.dropout,
+ config=config,
)
self.norm_attn = PatchTSMixerNormLayer(config)
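
The hunk above swaps PatchTSMixer's hand-rolled `torch.bmm` attention for a dispatch on `config._attn_implementation`: the eager math lives in a standalone `eager_attention_forward`, and any other value is looked up in `ALL_ATTENTION_FUNCTIONS`. Below is a rough, self-contained sketch of that pattern; the toy registry, function names, and shapes are ours, not the library's.

```
import torch
import torch.nn.functional as F


def eager_attention(query, key, value, attention_mask=None, scaling=None, dropout=0.0):
    # query/key/value: (batch, num_heads, seq_len, head_dim)
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        weights = weights + attention_mask  # additive mask: 0 keeps, large negative drops
    weights = F.dropout(F.softmax(weights, dim=-1), p=dropout)
    # back to (batch, seq_len, num_heads, head_dim), as in the refactored modules
    return torch.matmul(weights, value).transpose(1, 2).contiguous(), weights


def sdpa_attention(query, key, value, attention_mask=None, scaling=None, dropout=0.0):
    out = F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return out.transpose(1, 2).contiguous(), None


# Toy stand-in for ALL_ATTENTION_FUNCTIONS
ATTENTION_FUNCTIONS = {"eager": eager_attention, "sdpa": sdpa_attention}

q = k = v = torch.randn(2, 4, 16, 32)
for impl in ("eager", "sdpa"):
    out, _ = ATTENTION_FUNCTIONS[impl](q, k, v)
    print(impl, tuple(out.shape))  # (2, 16, 4, 32) for both backends
```

Because both callables share one signature, the module body stays identical regardless of which backend is configured.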
diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py
index 57b69b6b5ebc..b85e8a66b254 100755
--- a/src/transformers/models/patchtst/modeling_patchtst.py
+++ b/src/transformers/models/patchtst/modeling_patchtst.py
@@ -16,14 +16,16 @@
import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import torch
from torch import nn
from ...activations import ACT2CLS
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
from ...utils import ModelOutput, auto_docstring, logging
from .configuration_patchtst import PatchTSTConfig
@@ -32,7 +34,38 @@
logger = logging.get_logger(__name__)
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->PatchTST
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->PatchTST
class PatchTSTAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -67,9 +100,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -78,6 +108,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed kwargs for encoder, decoder, and encoder-decoder attention
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -85,10 +118,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -103,18 +142,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -126,69 +165,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
class PatchTSTBatchNorm(nn.Module):
@@ -461,6 +458,7 @@ def __init__(self, config: PatchTSTConfig):
embed_dim=config.d_model,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
+ config=config,
)
# Add & Norm of the sublayer 1
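
The same dispatch now backs `PatchTSTAttention`, so a randomly initialized PatchTST model exercises whichever backend is set on its config. A quick smoke test, assuming these `PatchTSTConfig` fields and the `(batch, sequence_length, num_input_channels)` input convention; the sizes are invented for illustration, not a released checkpoint.

```
import torch
from transformers import PatchTSTConfig, PatchTSTModel

# Tiny made-up configuration, just enough for a shape check.
config = PatchTSTConfig(
    num_input_channels=3,
    context_length=64,
    patch_length=8,
    patch_stride=8,
    d_model=32,
    num_attention_heads=4,
    num_hidden_layers=2,
)
model = PatchTSTModel(config).eval()

past_values = torch.randn(2, 64, 3)  # (batch, sequence_length, num_input_channels)
with torch.no_grad():
    outputs = model(past_values=past_values)
print(outputs.last_hidden_state.shape)  # per-channel patch embeddings of width d_model
```

A non-eager backend can be requested the usual way, via `attn_implementation=...` in `from_pretrained()`.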
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 3f59a8c9186b..19166bd6091d 100755
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -16,7 +16,7 @@
import copy
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -30,7 +30,9 @@
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -38,7 +40,8 @@
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -49,9 +52,7 @@
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -109,6 +110,37 @@ def forward(
return super().forward(position_ids)
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus
class PegasusAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -161,17 +193,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed kwargs for encoder, decoder, and encoder-decoder attention
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -192,8 +232,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -205,69 +245,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-PEGASUS_ATTENTION_CLASSES = {"eager": PegasusAttention}
+ return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Pegasus, MBART->PEGASUS
@@ -276,7 +274,7 @@ def __init__(self, config: PegasusConfig):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = PegasusAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -345,7 +343,7 @@ def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = PegasusAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -359,7 +357,7 @@ def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = PEGASUS_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = PegasusAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -427,7 +425,6 @@ def forward(
residual = hidden_states
hidden_states = self.encoder_attn_layer_norm(hidden_states)
- # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple
hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
hidden_states=hidden_states,
key_value_states=encoder_hidden_states,
@@ -464,6 +461,12 @@ class PegasusPreTrainedModel(PreTrainedModel):
config_class = PegasusConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
+ _supports_cache_class = True
+ _supports_static_cache = True
def _init_weights(self, module):
std = self.config.init_std
@@ -481,24 +484,56 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -506,7 +541,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -540,7 +575,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -606,6 +640,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
class PegasusEncoder(PegasusPreTrainedModel):
"""
@@ -748,10 +818,10 @@ def forward(
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -970,26 +1040,32 @@ def forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
# important to apply scale outside of `if` in case users pass `embeds`
inputs_embeds = inputs_embeds * self.embed_scale
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
# initialize `past_key_values`
return_legacy_cache = False
if use_cache and not isinstance(past_key_values, Cache):
@@ -1018,20 +1094,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position)
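
With `_supports_flash_attn_2` and `_supports_sdpa` now declared on `PegasusPreTrainedModel`, the standard loading path can request a fused backend. A hedged end-to-end sketch: `google/pegasus-xsum` is a real checkpoint, so this downloads a couple of GB of weights; half precision and a CUDA device are our assumptions, and any speedup depends on hardware.

```
import torch
from transformers import AutoTokenizer, PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained(
    "google/pegasus-xsum",
    attn_implementation="sdpa",  # or "flash_attention_2" if flash-attn is installed
    torch_dtype=torch.float16,
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("google/pegasus-xsum")

text = "PEGASUS was pre-trained with gap-sentence generation on large news corpora."
inputs = tokenizer(text, return_tensors="pt").to("cuda")
summary_ids = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True))
```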
diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
index 04cf37a7622e..a2fcf5edd1b6 100755
--- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py
+++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py
@@ -16,7 +16,7 @@
import dataclasses
import math
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -30,14 +30,17 @@
from ...modeling_attn_mask_utils import (
AttentionMaskConverter,
_prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
is_torch_flex_attn_available,
@@ -48,9 +51,7 @@
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -135,6 +136,37 @@ def forward(
return pe[None].expand(batch_size, -1, -1)
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PegasusX
class PegasusXAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -187,17 +219,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed kwargs for encoder, decoder, and encoder-decoder attention
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -218,8 +258,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -231,66 +271,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
class PegasusXGlobalLocalAttention(nn.Module):
@@ -653,6 +654,7 @@ def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None):
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
+ config=config,
layer_idx=layer_idx,
)
self.dropout = config.dropout
@@ -666,6 +668,7 @@ def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None):
dropout=config.attention_dropout,
is_decoder=True,
bias=False,
+ config=config,
layer_idx=layer_idx,
)
self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
@@ -758,6 +761,11 @@ class PegasusXPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"]
+ _supports_flash_attn_2 = True
+ # Flaky logits
+ _supports_sdpa = False
+ # Compile issues
+ _supports_flex_attn = False
_supports_cache_class = True
_supports_static_cache = True
@@ -773,24 +781,56 @@ def _init_weights(self, module):
module.weight.data.fill_(1.0)
module.bias.data.zero_()
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
def _update_causal_mask(
self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
input_tensor: torch.Tensor,
cache_position: torch.Tensor,
past_key_values: Cache,
- output_attentions: bool = False,
):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
if self.config._attn_implementation == "flex_attention":
if isinstance(attention_mask, torch.Tensor):
attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
return attention_mask
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
# For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
# order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
# to infer the attention mask.
@@ -798,7 +838,7 @@ def _update_causal_mask(
using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
# When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
if AttentionMaskConverter._ignore_causal_mask_sdpa(
attention_mask,
inputs_embeds=input_tensor,
@@ -832,7 +872,6 @@ def _update_causal_mask(
self.config._attn_implementation == "sdpa"
and attention_mask is not None
and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
):
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
@@ -898,6 +937,42 @@ def _prepare_4d_causal_attention_mask_with_cache_position(
return causal_mask
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
class PegasusXEncoder(PegasusXPreTrainedModel):
"""
@@ -1227,22 +1302,28 @@ def forward(
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
- if self.gradient_checkpointing and self.training:
- if use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
- )
- use_cache = False
-
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
+
+ if self.gradient_checkpointing and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
# initialize `past_key_values`
return_legacy_cache = False
@@ -1272,20 +1353,19 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
+
causal_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
position_ids = cache_position.unsqueeze(1)
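
Pegasus-X gains the same `_update_full_mask` / `_update_cross_attn_mask` helpers, and on the eager path they boil down to expanding a 2D padding mask into the 4D additive mask that `eager_attention_forward` adds to the attention scores. The following is a rough stand-in written from the observed behaviour of `_prepare_4d_attention_mask`, not copied from the library.

```
from typing import Optional

import torch


def expand_padding_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None) -> torch.Tensor:
    """(bsz, src_len) 0/1 mask -> (bsz, 1, tgt_len, src_len) additive mask: 0 keeps, dtype-min drops."""
    bsz, src_len = mask.shape
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


padding = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])               # two right-padded sequences
additive = expand_padding_mask(padding, torch.float32, tgt_len=2)  # e.g. 2 decoder queries over 4 encoder keys
print(additive.shape)  # torch.Size([2, 1, 2, 4])
print(additive[1, 0])  # the last two key positions carry the large negative fill
```

The flash-attention, SDPA, and flex-attention branches each build their own mask format instead, which is why the helpers switch on `config._attn_implementation`.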
diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py
index 614baee8bf32..85813b0242ca 100644
--- a/src/transformers/models/plbart/modeling_plbart.py
+++ b/src/transformers/models/plbart/modeling_plbart.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/plbart/modular_plbart.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_plbart.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
#
@@ -12,14 +18,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch PLBART model."""
import copy
import math
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import torch
-import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
@@ -31,6 +35,7 @@
_prepare_4d_attention_mask,
_prepare_4d_attention_mask_for_sdpa,
)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -39,47 +44,247 @@
Seq2SeqModelOutput,
Seq2SeqSequenceClassifierOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import (
- auto_docstring,
- is_torch_flex_attn_available,
- is_torchdynamo_compiling,
- logging,
-)
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, is_torchdynamo_compiling, logging
from .configuration_plbart import PLBartConfig
if is_torch_flex_attn_available():
- from torch.nn.attention.flex_attention import BlockMask
-
- from ...integrations.flex_attention import make_flex_block_causal_mask
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
logger = logging.get_logger(__name__)
-# Copied from transformers.models.mbart.modeling_mbart.shift_tokens_right
-def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+class PLBartScaledWordEmbedding(nn.Embedding):
"""
- Shift input ids one token to the right, and wrap the last non pad token (the token) Note that MBart does not
- have a single `decoder_start_token_id` in contrast to other Bart-like models.
+ This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
"""
- prev_output_tokens = input_ids.clone()
- if pad_token_id is None:
- raise ValueError("self.model.config.pad_token_id has to be defined.")
- # replace possible -100 values in labels by `pad_token_id`
- prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+ def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
+ super().__init__(num_embeddings, embedding_dim, padding_idx)
+ self.embed_scale = embed_scale
- index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
- decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
- prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
- prev_output_tokens[:, 0] = decoder_start_tokens
+ def forward(self, input_ids: torch.Tensor):
+ return super().forward(input_ids) * self.embed_scale
- return prev_output_tokens
+
+@auto_docstring
+class PLBartPreTrainedModel(PreTrainedModel):
+ config_class = PLBartConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Compile issues
+ _supports_flex_attn = False
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
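+ # Expands the 2D encoder padding mask into the form the configured attention backend expects: the raw 2D
+ # mask (or None when there is no padding) for flash attention, a block mask for flex attention, and a 4D
+ # additive mask for sdpa and eager.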
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask cannot be supported when using SDPA, so we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
-# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->PLBart
class PLBartLearnedPositionalEmbedding(nn.Embedding):
"""
This module learns positional embeddings up to a fixed maximum size.
@@ -105,21 +310,36 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, posi
return super().forward(position_ids + self.offset)
-# Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart
-class PLBartScaledWordEmbedding(nn.Embedding):
- """
- This module overrides nn.Embeddings' forward by multiplying with embeddings scale.
- """
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
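+ # query/key/value are expected as (bsz, num_heads, seq_len, head_dim); the scores below have shape
+ # (bsz, num_heads, tgt_len, src_len) and the output is returned as (bsz, tgt_len, num_heads, head_dim).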
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
- def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: Optional[float] = 1.0):
- super().__init__(num_embeddings, embedding_dim, padding_idx)
- self.embed_scale = embed_scale
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
- def forward(self, input_ids: torch.Tensor):
- return super().forward(input_ids) * self.embed_scale
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PLBart
class PLBartAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -171,17 +391,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+ # ATM, we have mixed kwargs for encoder, decoder, and encoder-decoder attention
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -202,8 +430,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -215,75 +443,35 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
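+ # All registered attention implementations share the eager signature (module, query, key, value,
+ # attention_mask, **kwargs) and return (attn_output, attn_weights), so only the backend changes here.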
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
-# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART
class PLBartEncoderLayer(nn.Module):
def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = PLBartAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -333,320 +521,22 @@ def forward(
hidden_states = self.fc2(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
- hidden_states = self.final_layer_norm(hidden_states)
-
- if hidden_states.dtype == torch.float16 and (
- torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
- ):
- clamp_value = torch.finfo(hidden_states.dtype).max - 1000
- hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (attn_weights,)
-
- return outputs
-
-
-# TODO: Implement attention with SDPA for PLBart.
-PLBART_ATTENTION_CLASSES = {"eager": PLBartAttention}
-
-
-# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART
-class PLBartDecoderLayer(nn.Module):
- def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None):
- super().__init__()
- self.embed_dim = config.d_model
-
- self.self_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
- embed_dim=self.embed_dim,
- num_heads=config.decoder_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=True,
- is_causal=True,
- config=config,
- layer_idx=layer_idx,
- )
- self.dropout = config.dropout
- self.activation_fn = ACT2FN[config.activation_function]
- self.activation_dropout = config.activation_dropout
-
- self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = PLBART_ATTENTION_CLASSES[config._attn_implementation](
- self.embed_dim,
- config.decoder_attention_heads,
- dropout=config.attention_dropout,
- is_decoder=True,
- config=config,
- layer_idx=layer_idx,
- )
- self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
- self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
- self.final_layer_norm = nn.LayerNorm(self.embed_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- attention_mask: Optional[torch.Tensor] = None,
- encoder_hidden_states: Optional[torch.Tensor] = None,
- encoder_attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
- past_key_value: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = True,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
- """
- Args:
- hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
- attention_mask (`torch.FloatTensor`): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- encoder_hidden_states (`torch.FloatTensor`):
- cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
- encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
- layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
- `(encoder_attention_heads,)`.
- cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
- size `(decoder_attention_heads,)`.
- past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
- output_attentions (`bool`, *optional*):
- Whether or not to return the attentions tensors of all attention layers. See `attentions` under
- returned tensors for more detail.
- cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
- Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
- cache in the correct position and to infer the complete sequence length.
- """
- residual = hidden_states
-
- # Self Attention
- hidden_states, self_attn_weights, past_key_value = self.self_attn(
- hidden_states=hidden_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- layer_head_mask=layer_head_mask,
- output_attentions=output_attentions,
- cache_position=cache_position,
- )
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- hidden_states = self.self_attn_layer_norm(hidden_states)
-
- # Cross-Attention Block
- cross_attn_weights = None
- if encoder_hidden_states is not None:
- residual = hidden_states
-
- hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
- hidden_states=hidden_states,
- key_value_states=encoder_hidden_states,
- attention_mask=encoder_attention_mask,
- layer_head_mask=cross_attn_layer_head_mask,
- past_key_value=past_key_value,
- output_attentions=output_attentions,
- )
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- hidden_states = self.encoder_attn_layer_norm(hidden_states)
-
- # Fully Connected
- residual = hidden_states
- hidden_states = self.activation_fn(self.fc1(hidden_states))
- hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
- hidden_states = self.fc2(hidden_states)
- hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- hidden_states = residual + hidden_states
- hidden_states = self.final_layer_norm(hidden_states)
-
- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights, cross_attn_weights)
-
- if use_cache:
- outputs += (past_key_value,)
-
- return outputs
-
-
-# Copied from transformers.models.bart.modeling_bart.BartClassificationHead with Bart->PLBart
-class PLBartClassificationHead(nn.Module):
- """Head for sentence-level classification tasks."""
-
- def __init__(
- self,
- input_dim: int,
- inner_dim: int,
- num_classes: int,
- pooler_dropout: float,
- ):
- super().__init__()
- self.dense = nn.Linear(input_dim, inner_dim)
- self.dropout = nn.Dropout(p=pooler_dropout)
- self.out_proj = nn.Linear(inner_dim, num_classes)
-
- def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.dense(hidden_states)
- hidden_states = torch.tanh(hidden_states)
- hidden_states = self.dropout(hidden_states)
- hidden_states = self.out_proj(hidden_states)
- return hidden_states
-
-
-@auto_docstring
-class PLBartPreTrainedModel(PreTrainedModel):
- config_class = PLBartConfig
- base_model_prefix = "model"
- supports_gradient_checkpointing = True
- _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
- _supports_cache_class = True
- _supports_static_cache = True
-
- def _init_weights(self, module):
- std = self.config.init_std
- if isinstance(module, nn.Linear):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.bias is not None:
- module.bias.data.zero_()
- elif isinstance(module, nn.Embedding):
- module.weight.data.normal_(mean=0.0, std=std)
- if module.padding_idx is not None:
- module.weight.data[module.padding_idx].zero_()
- elif isinstance(module, nn.LayerNorm):
- module.weight.data.fill_(1.0)
- module.bias.data.zero_()
-
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
- def _update_causal_mask(
- self,
- attention_mask: Union[torch.Tensor, "BlockMask"],
- input_tensor: torch.Tensor,
- cache_position: torch.Tensor,
- past_key_values: Cache,
- output_attentions: bool = False,
- ):
- if self.config._attn_implementation == "flash_attention_2":
- if attention_mask is not None and (attention_mask == 0.0).any():
- return attention_mask
- return None
- if self.config._attn_implementation == "flex_attention":
- if isinstance(attention_mask, torch.Tensor):
- attention_mask = make_flex_block_causal_mask(attention_mask)
- return attention_mask
-
- # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
- # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
- # to infer the attention mask.
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
-
- # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
- if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
- if AttentionMaskConverter._ignore_causal_mask_sdpa(
- attention_mask,
- inputs_embeds=input_tensor,
- past_key_values_length=past_seen_tokens,
- is_training=self.training,
- ):
- return None
-
- dtype = input_tensor.dtype
- sequence_length = input_tensor.shape[1]
- if using_compilable_cache:
- target_length = past_key_values.get_max_cache_shape()
- else:
- target_length = (
- attention_mask.shape[-1]
- if isinstance(attention_mask, torch.Tensor)
- else past_seen_tokens + sequence_length + 1
- )
-
- # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
- causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask,
- sequence_length=sequence_length,
- target_length=target_length,
- dtype=dtype,
- cache_position=cache_position,
- batch_size=input_tensor.shape[0],
- )
-
- if (
- self.config._attn_implementation == "sdpa"
- and attention_mask is not None
- and attention_mask.device.type in ["cuda", "xpu", "npu"]
- and not output_attentions
- ):
- # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
- # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
- # Details: https://github.com/pytorch/pytorch/issues/110213
- min_dtype = torch.finfo(dtype).min
- causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
-
- return causal_mask
-
- @staticmethod
- # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
- def _prepare_4d_causal_attention_mask_with_cache_position(
- attention_mask: torch.Tensor,
- sequence_length: int,
- target_length: int,
- dtype: torch.dtype,
- cache_position: torch.Tensor,
- batch_size: int,
- **kwargs,
- ):
- """
- Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
- `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+ hidden_states = self.final_layer_norm(hidden_states)
- Args:
- attention_mask (`torch.Tensor`):
- A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
- `(batch_size, 1, query_length, key_value_length)`.
- sequence_length (`int`):
- The sequence length being processed.
- target_length (`int`):
- The target length: when generating with static cache, the mask should be as long as the static cache,
- to account for the 0 padding, the part of the cache that is not filled yet.
- dtype (`torch.dtype`):
- The dtype to use for the 4D attention mask.
- cache_position (`torch.Tensor`):
- Indices depicting the position of the input sequence tokens in the sequence.
- batch_size (`torch.Tensor`):
- Batch size.
- """
- if attention_mask is not None and attention_mask.dim() == 4:
- # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
- causal_mask = attention_mask
- else:
- min_dtype = torch.finfo(dtype).min
- causal_mask = torch.full(
- (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
- )
- if sequence_length != 1:
- causal_mask = torch.triu(causal_mask, diagonal=1)
- causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
- causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
- if attention_mask is not None:
- causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
- mask_length = attention_mask.shape[-1]
- padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
- causal_mask.device
- )
- padding_mask = padding_mask == 0
- causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
- padding_mask, min_dtype
- )
+ if hidden_states.dtype == torch.float16 and (
+ torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()
+ ):
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
- return causal_mask
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs
-# Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart
class PLBartEncoder(PLBartPreTrainedModel):
"""
Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -680,8 +570,6 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
embed_dim,
)
self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)])
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(embed_dim)
self.gradient_checkpointing = False
@@ -767,18 +655,10 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- if self._use_flash_attention_2:
- attention_mask = attention_mask if 0 in attention_mask else None
- elif self._use_sdpa and head_mask is None and not output_attentions:
- # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -835,7 +715,125 @@ def forward(
)
-# Copied from transformers.models.bart.modeling_bart.BartDecoder with Bart->PLBart
+class PLBartDecoderLayer(nn.Module):
+ def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None):
+ super().__init__()
+ self.embed_dim = config.d_model
+
+ self.self_attn = PLBartAttention(
+ embed_dim=self.embed_dim,
+ num_heads=config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ is_causal=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.dropout = config.dropout
+ self.activation_fn = ACT2FN[config.activation_function]
+ self.activation_dropout = config.activation_dropout
+
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.encoder_attn = PLBartAttention(
+ self.embed_dim,
+ config.decoder_attention_heads,
+ dropout=config.attention_dropout,
+ is_decoder=True,
+ config=config,
+ layer_idx=layer_idx,
+ )
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim)
+ self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim)
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ layer_head_mask: Optional[torch.Tensor] = None,
+ cross_attn_layer_head_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Cache] = None,
+ output_attentions: Optional[bool] = False,
+ use_cache: Optional[bool] = True,
+ cache_position: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
+ """
+ Args:
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
+ attention_mask (`torch.FloatTensor`): attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ encoder_hidden_states (`torch.FloatTensor`):
+ cross attention input to the layer of shape `(batch, seq_len, embed_dim)`
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
+ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ layer_head_mask (`torch.FloatTensor`): mask for attention heads in a given layer of size
+ `(encoder_attention_heads,)`.
+ cross_attn_layer_head_mask (`torch.FloatTensor`): mask for cross-attention heads in a given layer of
+ size `(decoder_attention_heads,)`.
+ past_key_value (`Tuple(torch.FloatTensor)`): cached past key and value projection states
+ output_attentions (`bool`, *optional*):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more detail.
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
+ Indices depicting the position of the input sequence tokens in the sequence. It is used to update the
+ cache in the correct position and to infer the complete sequence length.
+ """
+ residual = hidden_states
+
+ # Self Attention
+ hidden_states, self_attn_weights, past_key_value = self.self_attn(
+ hidden_states=hidden_states,
+ past_key_value=past_key_value,
+ attention_mask=attention_mask,
+ layer_head_mask=layer_head_mask,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.self_attn_layer_norm(hidden_states)
+
+ # Cross-Attention Block
+ cross_attn_weights = None
+ if encoder_hidden_states is not None:
+ residual = hidden_states
+
+ hidden_states, cross_attn_weights, past_key_value = self.encoder_attn(
+ hidden_states=hidden_states,
+ key_value_states=encoder_hidden_states,
+ attention_mask=encoder_attention_mask,
+ layer_head_mask=cross_attn_layer_head_mask,
+ past_key_value=past_key_value,
+ output_attentions=output_attentions,
+ cache_position=cache_position,
+ )
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
+
+ # Fully Connected
+ residual = hidden_states
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
+ hidden_states = self.fc2(hidden_states)
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
+ hidden_states = residual + hidden_states
+ hidden_states = self.final_layer_norm(hidden_states)
+
+ outputs = (hidden_states,)
+
+ if output_attentions:
+ outputs += (self_attn_weights, cross_attn_weights)
+
+ if use_cache:
+ outputs += (past_key_value,)
+
+ return outputs
+
+
class PLBartDecoder(PLBartPreTrainedModel):
"""
Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`PLBartDecoderLayer`]
@@ -865,8 +863,6 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] =
config.d_model,
)
self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)])
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
- self._use_sdpa = config._attn_implementation == "sdpa"
self.layernorm_embedding = nn.LayerNorm(config.d_model)
@@ -981,12 +977,18 @@ def forward(
# retrieve input_ids and inputs_embeds
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
-
- if input_ids is not None:
- input_ids = input_ids.view(-1, input_ids.shape[-1])
+ elif input_ids is not None:
+ input = input_ids
+ input_shape = input.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.size()[:-1]
+ input = inputs_embeds[:, :, -1]
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
if inputs_embeds is None:
- inputs_embeds = self.embed_tokens(input_ids)
+ inputs_embeds = self.embed_tokens(input)
# initialize `past_key_values`
return_legacy_cache = False
@@ -1016,38 +1018,25 @@ def forward(
if isinstance(past_key_values, EncoderDecoderCache)
else past_key_values
)
- causal_mask = self._update_causal_mask(
+
+ attention_mask = self._update_causal_mask(
attention_mask,
inputs_embeds,
cache_position,
self_attn_cache,
- output_attentions,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- if self._use_flash_attention_2:
- encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
- elif self._use_sdpa and cross_attn_head_mask is None and not output_attentions:
- # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
- # the manual implementation that requires a 4D causal mask in all cases.
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
- encoder_attention_mask,
- inputs_embeds.dtype,
- tgt_len=seq_length,
- )
- else:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length
- )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
+ )
# embed positions
- position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
- position_ids = position_ids.to(inputs_embeds.device)
+ positions = self.embed_positions(input, past_key_values_length, position_ids=cache_position)
+ positions = positions.to(inputs_embeds.device)
- hidden_states = inputs_embeds + position_ids
+ hidden_states = inputs_embeds + positions
hidden_states = self.layernorm_embedding(hidden_states)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
@@ -1080,7 +1069,7 @@ def forward(
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
- causal_mask,
+ attention_mask,
encoder_hidden_states,
encoder_attention_mask,
head_mask[idx] if head_mask is not None else None,
@@ -1093,7 +1082,7 @@ def forward(
else:
layer_outputs = decoder_layer(
hidden_states,
- attention_mask=causal_mask,
+ attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
layer_head_mask=(head_mask[idx] if head_mask is not None else None),
@@ -1139,6 +1128,26 @@ def forward(
)
+def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int):
+ """
+ Shift input ids one token to the right, and wrap the last non-pad token (the language id token) around to the
+ front. Note that PLBart does not have a single `decoder_start_token_id`, in contrast to other Bart-like models.
+ """
+ prev_output_tokens = input_ids.clone()
+
+ if pad_token_id is None:
+ raise ValueError("self.model.config.pad_token_id has to be defined.")
+ # replace possible -100 values in labels by `pad_token_id`
+ prev_output_tokens.masked_fill_(prev_output_tokens == -100, pad_token_id)
+
+ index_of_eos = (prev_output_tokens.ne(pad_token_id).sum(dim=1) - 1).unsqueeze(-1)
+ decoder_start_tokens = prev_output_tokens.gather(1, index_of_eos).squeeze()
+ prev_output_tokens[:, 1:] = prev_output_tokens[:, :-1].clone()
+ prev_output_tokens[:, 0] = decoder_start_tokens
+
+ return prev_output_tokens
+
+
@auto_docstring
class PLBartModel(PLBartPreTrainedModel):
_tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
@@ -1192,7 +1201,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- cache_position: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1350,7 +1359,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
- cache_position: Optional[torch.Tensor] = None,
+ cache_position: Optional[torch.LongTensor] = None,
) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
r"""
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1469,10 +1478,34 @@ def _reorder_cache(past_key_values, beam_idx):
return reordered_past
+class PLBartClassificationHead(nn.Module):
+ """Head for sentence-level classification tasks."""
+
+ def __init__(
+ self,
+ input_dim: int,
+ inner_dim: int,
+ num_classes: int,
+ pooler_dropout: float,
+ ):
+ super().__init__()
+ self.dense = nn.Linear(input_dim, inner_dim)
+ self.dropout = nn.Dropout(p=pooler_dropout)
+ self.out_proj = nn.Linear(inner_dim, num_classes)
+
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.dense(hidden_states)
+ hidden_states = torch.tanh(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.out_proj(hidden_states)
+ return hidden_states
+
+
@auto_docstring(
custom_intro="""
- PLBart model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for code
- classification.
+ PLBart model with a sequence classification head on top (a linear layer on top of the pooled output), e.g.
+ for GLUE tasks.
"""
)
class PLBartForSequenceClassification(PLBartPreTrainedModel):
@@ -1492,7 +1525,6 @@ def __init__(self, config: PLBartConfig, **kwargs):
self.post_init()
@auto_docstring
- # Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -1621,7 +1653,6 @@ def forward(
)
-# Copied from transformers.models.bart.modeling_bart.BartDecoderWrapper with Bart->PLBart
class PLBartDecoderWrapper(PLBartPreTrainedModel):
"""
This wrapper class is a helper class to correctly load pretrained checkpoints when the causal language model is
@@ -1636,7 +1667,11 @@ def forward(self, *args, **kwargs):
return self.decoder(*args, **kwargs)
-# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
+@auto_docstring(
+ custom_intro="""
+ PLBART decoder with a language modeling head on top (linear layer with weights tied to the input embeddings).
+ """
+)
class PLBartForCausalLM(PLBartPreTrainedModel, GenerationMixin):
_tied_weights_keys = ["lm_head.weight"]
diff --git a/src/transformers/models/plbart/modular_plbart.py b/src/transformers/models/plbart/modular_plbart.py
new file mode 100644
index 000000000000..8a9755b11b17
--- /dev/null
+++ b/src/transformers/models/plbart/modular_plbart.py
@@ -0,0 +1,692 @@
+# coding=utf-8
+# Copyright 2022, UCLA NLP, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch PLBART model."""
+
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+from torch.nn import CrossEntropyLoss
+
+from ...cache_utils import Cache
+from ...generation import GenerationMixin
+from ...modeling_attn_mask_utils import (
+ AttentionMaskConverter,
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_outputs import (
+ BaseModelOutput,
+ Seq2SeqLMOutput,
+ Seq2SeqModelOutput,
+)
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring, is_torch_flex_attn_available
+from ..bart.modeling_bart import (
+ BartClassificationHead,
+ BartDecoder,
+ BartEncoder,
+ BartForCausalLM,
+ BartScaledWordEmbedding,
+)
+from ..bigbird_pegasus.modeling_bigbird_pegasus import BigBirdPegasusForSequenceClassification
+from ..mbart.modeling_mbart import shift_tokens_right
+from .configuration_plbart import PLBartConfig
+
+
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import BlockMask, make_flex_block_causal_mask
+
+
+class PLBartScaledWordEmbedding(BartScaledWordEmbedding):
+ pass
+
+
+@auto_docstring
+class PLBartPreTrainedModel(PreTrainedModel):
+ config_class = PLBartConfig
+ base_model_prefix = "model"
+ supports_gradient_checkpointing = True
+ _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"]
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ # Flex attention is disabled for now due to compile issues
+ _supports_flex_attn = False
+
+ def _init_weights(self, module):
+ std = self.config.init_std
+ if isinstance(module, nn.Linear):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, nn.Embedding):
+ module.weight.data.normal_(mean=0.0, std=std)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask cannot be supported when using SDPA, so we fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Optional[Union[torch.Tensor, "BlockMask"]],
+ input_tensor: torch.Tensor,
+ cache_position: torch.Tensor,
+ past_key_values: Cache,
+ ):
+ if self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_tensor.shape[0], input_tensor.shape[1]),
+ device=input_tensor.device,
+ )
+ )
+ return attention_mask
+
+ if self.config._attn_implementation == "flash_attention_2":
+ if attention_mask is not None and (attention_mask == 0.0).any():
+ return attention_mask
+ return None
+
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+ # to infer the attention mask.
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+ using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
+
+ # When output_attentions=True, the sdpa implementation's forward method falls back to the eager implementation's forward
+ if self.config._attn_implementation == "sdpa" and not using_compilable_cache:
+ if AttentionMaskConverter._ignore_causal_mask_sdpa(
+ attention_mask,
+ inputs_embeds=input_tensor,
+ past_key_values_length=past_seen_tokens,
+ is_training=self.training,
+ ):
+ return None
+
+ dtype = input_tensor.dtype
+ sequence_length = input_tensor.shape[1]
+ if using_compilable_cache:
+ target_length = past_key_values.get_max_cache_shape()
+ else:
+ target_length = (
+ attention_mask.shape[-1]
+ if isinstance(attention_mask, torch.Tensor)
+ else past_seen_tokens + sequence_length + 1
+ )
+
+ # In case the provided `attention_mask` is 2D, we generate a causal mask here (4D).
+ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask,
+ sequence_length=sequence_length,
+ target_length=target_length,
+ dtype=dtype,
+ cache_position=cache_position,
+ batch_size=input_tensor.shape[0],
+ )
+
+ if (
+ self.config._attn_implementation == "sdpa"
+ and attention_mask is not None
+ and attention_mask.device.type in ["cuda", "xpu", "npu"]
+ ):
+ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+ # Details: https://github.com/pytorch/pytorch/issues/110213
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+ return causal_mask
+
+ @staticmethod
+ # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
+ def _prepare_4d_causal_attention_mask_with_cache_position(
+ attention_mask: torch.Tensor,
+ sequence_length: int,
+ target_length: int,
+ dtype: torch.dtype,
+ cache_position: torch.Tensor,
+ batch_size: int,
+ **kwargs,
+ ):
+ """
+ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
+ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
+
+ Args:
+ attention_mask (`torch.Tensor`):
+ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
+ `(batch_size, 1, query_length, key_value_length)`.
+ sequence_length (`int`):
+ The sequence length being processed.
+ target_length (`int`):
+ The target length: when generating with static cache, the mask should be as long as the static cache,
+ to account for the 0 padding, i.e. the part of the cache that is not filled yet.
+ dtype (`torch.dtype`):
+ The dtype to use for the 4D attention mask.
+ cache_position (`torch.Tensor`):
+ Indices depicting the position of the input sequence tokens in the sequence.
+ batch_size (`int`):
+ Batch size.
+ """
+ if attention_mask is not None and attention_mask.dim() == 4:
+ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
+ causal_mask = attention_mask
+ else:
+ min_dtype = torch.finfo(dtype).min
+ causal_mask = torch.full(
+ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
+ )
+ if sequence_length != 1:
+ causal_mask = torch.triu(causal_mask, diagonal=1)
+ causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
+ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
+ if attention_mask is not None:
+ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
+ mask_length = attention_mask.shape[-1]
+ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
+ causal_mask.device
+ )
+ padding_mask = padding_mask == 0
+ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+ padding_mask, min_dtype
+ )
+
+ return causal_mask
+
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask cannot be supported when using SDPA, so we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
+
+class PLBartEncoder(BartEncoder):
+ pass
+
+
+class PLBartDecoder(BartDecoder):
+ pass
+
+
+@auto_docstring
+class PLBartModel(PLBartPreTrainedModel):
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
+ def __init__(self, config: PLBartConfig):
+ super().__init__(config)
+
+ padding_idx, vocab_size = config.pad_token_id, config.vocab_size
+ embed_scale = math.sqrt(config.d_model) if config.scale_embedding else 1.0
+ self.shared = PLBartScaledWordEmbedding(vocab_size, config.d_model, padding_idx, embed_scale=embed_scale)
+
+ self.encoder = PLBartEncoder(config, self.shared)
+ self.decoder = PLBartDecoder(config, self.shared)
+
+ self.init_weights()
+
+ def get_input_embeddings(self):
+ return self.shared
+
+ def set_input_embeddings(self, value):
+ self.shared = value
+ self.encoder.embed_tokens = self.shared
+ self.decoder.embed_tokens = self.shared
+
+ def _tie_weights(self):
+ if self.config.tie_word_embeddings:
+ self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+ self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+
+ def get_encoder(self):
+ return self.encoder
+
+ def get_decoder(self):
+ return self.decoder
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.LongTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
+ See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+ varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+ for denoising pre-training following the paper.
+ decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
+ be used by default.
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # different to other models, PLBart automatically creates decoder_input_ids from
+ # input_ids if no decoder_input_ids are provided
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(input_ids, self.config.pad_token_id)
+
+ if encoder_outputs is None:
+ encoder_outputs = self.encoder(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ # If the user passed a tuple for encoder_outputs, we wrap it in a BaseModelOutput when return_dict=True
+ elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
+ encoder_outputs = BaseModelOutput(
+ last_hidden_state=encoder_outputs[0],
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
+ )
+
+ # decoder outputs consists of (dec_features, past_key_value, dec_hidden, dec_attn)
+ decoder_outputs = self.decoder(
+ input_ids=decoder_input_ids,
+ attention_mask=decoder_attention_mask,
+ encoder_hidden_states=encoder_outputs[0],
+ encoder_attention_mask=attention_mask,
+ head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+
+ if not return_dict:
+ return decoder_outputs + encoder_outputs
+
+ return Seq2SeqModelOutput(
+ last_hidden_state=decoder_outputs.last_hidden_state,
+ past_key_values=decoder_outputs.past_key_values,
+ decoder_hidden_states=decoder_outputs.hidden_states,
+ decoder_attentions=decoder_outputs.attentions,
+ cross_attentions=decoder_outputs.cross_attentions,
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
+ encoder_hidden_states=encoder_outputs.hidden_states,
+ encoder_attentions=encoder_outputs.attentions,
+ )
+
+
+@auto_docstring(
+ custom_intro="""
+ The PLBART Model with a language modeling head. Can be used for code-to-text, text-to-code and code-to-code.
+ """
+)
+class PLBartForConditionalGeneration(PLBartPreTrainedModel, GenerationMixin):
+ base_model_prefix = "model"
+ _keys_to_ignore_on_load_missing = ["final_logits_bias"]
+ _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
+
+ def __init__(self, config: PLBartConfig):
+ super().__init__(config)
+ self.model = PLBartModel(config)
+ self.register_buffer("final_logits_bias", torch.zeros((1, self.model.shared.num_embeddings)))
+ self.lm_head = nn.Linear(config.d_model, self.model.shared.num_embeddings, bias=False)
+
+ self.init_weights()
+
+ def get_encoder(self):
+ return self.model.get_encoder()
+
+ def get_decoder(self):
+ return self.model.get_decoder()
+
+ def resize_token_embeddings(
+ self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, mean_resizing: bool = True
+ ) -> nn.Embedding:
+ new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
+ self._resize_final_logits_bias(new_embeddings.weight.shape[0])
+ return new_embeddings
+
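+ # `final_logits_bias` is registered as a buffer (not a Parameter) in `__init__`, so it is never updated by
+ # the optimizer and must be resized by hand whenever the token embeddings change size.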
+ def _resize_final_logits_bias(self, new_num_tokens: int) -> None:
+ old_num_tokens = self.final_logits_bias.shape[-1]
+ if new_num_tokens <= old_num_tokens:
+ new_bias = self.final_logits_bias[:, :new_num_tokens]
+ else:
+ extra_bias = torch.zeros((1, new_num_tokens - old_num_tokens), device=self.final_logits_bias.device)
+ new_bias = torch.cat([self.final_logits_bias, extra_bias], dim=1)
+ self.register_buffer("final_logits_bias", new_bias)
+
+ def get_output_embeddings(self):
+ return self.lm_head
+
+ def set_output_embeddings(self, new_embeddings):
+ self.lm_head = new_embeddings
+
+ @auto_docstring
+ def forward(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ decoder_input_ids: Optional[torch.LongTensor] = None,
+ decoder_attention_mask: Optional[torch.Tensor] = None,
+ head_mask: Optional[torch.Tensor] = None,
+ decoder_head_mask: Optional[torch.LongTensor] = None,
+ cross_attn_head_mask: Optional[torch.Tensor] = None,
+ encoder_outputs: Optional[List[torch.FloatTensor]] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.Tensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]:
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
+ See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+ varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+ for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
+            also be used by default.
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+            `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Mask-filling example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PLBartForConditionalGeneration
+
+ >>> model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")
+ >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
+
+ >>> # en_XX is the language symbol id for English
+        >>> TXT = "<s> Is 0 the <mask> Fibonacci number ? en_XX"
+ >>> input_ids = tokenizer([TXT], add_special_tokens=False, return_tensors="pt").input_ids
+
+ >>> logits = model(input_ids).logits
+ >>> masked_index = (input_ids[0] == tokenizer.mask_token_id).nonzero().item()
+ >>> probs = logits[0, masked_index].softmax(dim=0)
+ >>> values, predictions = probs.topk(5)
+
+ >>> tokenizer.decode(predictions).split()
+ ['first', 'same', 'highest', 'result', 'number']
+ ```
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if labels is not None:
+ if decoder_input_ids is None and decoder_inputs_embeds is None:
+ decoder_input_ids = shift_tokens_right(labels, self.config.pad_token_id)
+
+ outputs = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ decoder_input_ids=decoder_input_ids,
+ encoder_outputs=encoder_outputs,
+ decoder_attention_mask=decoder_attention_mask,
+ head_mask=head_mask,
+ decoder_head_mask=decoder_head_mask,
+ cross_attn_head_mask=cross_attn_head_mask,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ decoder_inputs_embeds=decoder_inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ cache_position=cache_position,
+ )
+ lm_logits = self.lm_head(outputs[0])
+ lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device)
+
+ masked_lm_loss = None
+ if labels is not None:
+ loss_fct = CrossEntropyLoss()
+ masked_lm_loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.view(-1))
+
+ if not return_dict:
+ output = (lm_logits,) + outputs[1:]
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+ return Seq2SeqLMOutput(
+ loss=masked_lm_loss,
+ logits=lm_logits,
+ past_key_values=outputs.past_key_values,
+ decoder_hidden_states=outputs.decoder_hidden_states,
+ decoder_attentions=outputs.decoder_attentions,
+ cross_attentions=outputs.cross_attentions,
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
+ encoder_hidden_states=outputs.encoder_hidden_states,
+ encoder_attentions=outputs.encoder_attentions,
+ )
+
+ def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+ return shift_tokens_right(labels, self.config.pad_token_id)
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ # cached cross_attention states don't have to be reordered -> they are always the same
+ reordered_past += (
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past[:2])
+ + layer_past[2:],
+ )
+ return reordered_past
+
+
+class PLBartClassificationHead(BartClassificationHead):
+ pass
+
+
+class PLBartForSequenceClassification(BigBirdPegasusForSequenceClassification):
+ def forward(**super_kwargs):
+ r"""
+ decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+ Indices of decoder input sequence tokens in the vocabulary.
+
+ Indices can be obtained using [`AutoTokenizer`] or [`PLBartMultiTokenizer`] depending on the checkpoint.
+ See [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
+
+ [What are decoder input IDs?](../glossary#decoder-input-ids)
+
+ PLBart uses a specific language id token as the starting token for `decoder_input_ids` generation that
+ varies according to source and target language, *e.g.* 50003 for *en_XX*, and 50001 for *java*. If
+ `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
+ `past_key_values`).
+
+ For translation and summarization training, `decoder_input_ids` should be provided. If no
+ `decoder_input_ids` is provided, the model will create this tensor by shifting the `input_ids` to the right
+ for denoising pre-training following the paper.
+        decoder_attention_mask (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
+            Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will
+            also be used by default.
+        cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+            Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
+            `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ super().forward(**super_kwargs)
+
+
+class PLBartForCausalLM(BartForCausalLM):
+ @auto_docstring
+ def forward(**super_kwargs):
+ r"""
+ cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*):
+ Mask to nullify selected heads of the cross-attention modules. Mask values selected in `[0, 1]`:
+
+ - 1 indicates the head is **not masked**,
+ - 0 indicates the head is **masked**.
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Example:
+
+ ```python
+ >>> from transformers import AutoTokenizer, PLBartForCausalLM
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
+ >>> model = PLBartForCausalLM.from_pretrained("uclanlp/plbart-base", add_cross_attention=False)
+ >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
+ >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
+ >>> outputs = model(**inputs)
+
+ >>> logits = outputs.logits
+ >>> expected_shape = [1, inputs.input_ids.shape[-1], model.config.vocab_size]
+ >>> list(logits.shape) == expected_shape
+ True
+ ```"""
+ super().forward(**super_kwargs)
+
+
+__all__ = [
+ "PLBartForCausalLM",
+ "PLBartForConditionalGeneration",
+ "PLBartForSequenceClassification",
+ "PLBartModel",
+ "PLBartPreTrainedModel",
+]
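
A minimal usage sketch of the conditional-generation head defined above, assuming the public `uclanlp/plbart-base` checkpoint and the standard `transformers` API. It exercises the two behaviours worth noting in the class: passing only `labels` makes `forward` build `decoder_input_ids` by shifting the labels right, and `resize_token_embeddings` keeps the registered `final_logits_bias` buffer in sync with the new vocabulary size.

```python
import torch
from transformers import AutoTokenizer, PLBartForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("uclanlp/plbart-base")
model = PLBartForConditionalGeneration.from_pretrained("uclanlp/plbart-base")

inputs = tokenizer("def add(a, b): return a + b", return_tensors="pt")
labels = tokenizer("Add two numbers.", return_tensors="pt").input_ids

# Passing only `labels` takes the shift-right path in `forward` above.
with torch.no_grad():
    outputs = model(**inputs, labels=labels)
print(outputs.loss, outputs.logits.shape)

# Growing the vocabulary also grows the `final_logits_bias` buffer.
model.resize_token_embeddings(model.config.vocab_size + 8)
assert model.final_logits_bias.shape[-1] == model.get_input_embeddings().num_embeddings
```
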
diff --git a/src/transformers/models/sew/feature_extractor_sew.py b/src/transformers/models/sew/feature_extractor_sew.py
new file mode 100644
index 000000000000..c58812b58d6b
--- /dev/null
+++ b/src/transformers/models/sew/feature_extractor_sew.py
@@ -0,0 +1,34 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sew/modular_sew.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sew.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import warnings
+
+from .modeling_sew import SEWFeatureEncoder
+
+
+class SEWFeatureExtractor(SEWFeatureEncoder):
+ def __init__(self, config):
+ super().__init__(config)
+ warnings.warn(
+            f"The class `{self.__class__.__name__}` has been deprecated "
+ "and will be removed in Transformers v5. "
+ f"Use `{self.__class__.__bases__[0].__name__}` instead.",
+ FutureWarning,
+ )
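
A quick check of the backward-compatibility alias above: instantiating it should only emit the `FutureWarning` pointing at `SEWFeatureEncoder`. The import path follows this new file, and a default `SEWConfig` is used purely to build a throwaway instance.

```python
import warnings

from transformers import SEWConfig
from transformers.models.sew.feature_extractor_sew import SEWFeatureExtractor  # deprecated alias

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    SEWFeatureExtractor(SEWConfig())

# The warning names the parent class to migrate to (SEWFeatureEncoder).
print(caught[0].category.__name__, "->", str(caught[0].message))
```
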
diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py
index b4d8584fe0ce..330cd99a7b4e 100644
--- a/src/transformers/models/sew/modeling_sew.py
+++ b/src/transformers/models/sew/modeling_sew.py
@@ -1,3 +1,9 @@
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# This file was automatically generated from src/transformers/models/sew/modular_sew.py.
+# Do NOT edit this file manually as any edits will be overwritten by the generation of
+# the file from the modular. If any change should be done, please apply the change to the
+# modular_sew.py file directly. One of our CI enforces this.
+# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
# coding=utf-8
# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
#
@@ -12,160 +18,30 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-"""PyTorch SEW model."""
import math
import warnings
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
-import torch.utils.checkpoint
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import auto_docstring, logging
-from ...utils.deprecation import deprecate_kwarg
from .configuration_sew import SEWConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
-
-
logger = logging.get_logger(__name__)
-_HIDDEN_STATES_START_POSITION = 1
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2._compute_mask_indices
-def _compute_mask_indices(
- shape: Tuple[int, int],
- mask_prob: float,
- mask_length: int,
- attention_mask: Optional[torch.LongTensor] = None,
- min_masks: int = 0,
-) -> np.ndarray:
- """
- Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
- ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
- CPU as part of the preprocessing during training.
-
- Args:
- shape: The shape for which to compute masks. This should be of a tuple of size 2 where
- the first element is the batch size and the second element is the length of the axis to span.
- mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
- independently generated mask spans of length `mask_length` is computed by
- `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
- actual percentage will be smaller.
- mask_length: size of the mask
- min_masks: minimum number of masked spans
- attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
- each batch dimension.
- """
- batch_size, sequence_length = shape
-
- if mask_length < 1:
- raise ValueError("`mask_length` has to be bigger than 0.")
-
- if mask_length > sequence_length:
- raise ValueError(
- f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
- f" and `sequence_length`: {sequence_length}`"
- )
-
- # epsilon is used for probabilistic rounding
- epsilon = np.random.rand(1).item()
-
- def compute_num_masked_span(input_length):
- """Given input length, compute how many spans should be masked"""
- num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
- num_masked_span = max(num_masked_span, min_masks)
-
- # make sure num masked span <= sequence_length
- if num_masked_span * mask_length > sequence_length:
- num_masked_span = sequence_length // mask_length
-
- # make sure num_masked span is also <= input_length - (mask_length - 1)
- if input_length - (mask_length - 1) < num_masked_span:
- num_masked_span = max(input_length - (mask_length - 1), 0)
-
- return num_masked_span
-
- # compute number of masked spans in batch
- input_lengths = (
- attention_mask.detach().sum(-1).tolist()
- if attention_mask is not None
- else [sequence_length for _ in range(batch_size)]
- )
-
- # SpecAugment mask to fill
- spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
- spec_aug_mask_idxs = []
-
- max_num_masked_span = compute_num_masked_span(sequence_length)
-
- if max_num_masked_span == 0:
- return spec_aug_mask
-
- for input_length in input_lengths:
- # compute num of masked spans for this input
- num_masked_span = compute_num_masked_span(input_length)
-
- # get random indices to mask
- spec_aug_mask_idx = np.random.choice(
- np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
- )
-
- # pick first sampled index that will serve as a dummy index to pad vector
- # to ensure same dimension for all batches due to probabilistic rounding
- # Picking first sample just pads those vectors twice.
- if len(spec_aug_mask_idx) == 0:
- # this case can only happen if `input_length` is strictly smaller then
- # `sequence_length` in which case the last token has to be a padding
- # token which we can use as a dummy mask id
- dummy_mask_idx = sequence_length - 1
- else:
- dummy_mask_idx = spec_aug_mask_idx[0]
-
- spec_aug_mask_idx = np.concatenate(
- [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
- )
- spec_aug_mask_idxs.append(spec_aug_mask_idx)
-
- spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
-
- # expand masked indices to masked spans
- spec_aug_mask_idxs = np.broadcast_to(
- spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
-
- # add offset to the starting indexes so that indexes now create a span
- offsets = np.arange(mask_length)[None, None, :]
- offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
- batch_size, max_num_masked_span * mask_length
- )
- spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
-
- # ensure that we cannot have indices larger than sequence_length
- if spec_aug_mask_idxs.max() > sequence_length - 1:
- spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
-
- # scatter indices to mask
- np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
-
- return spec_aug_mask
-
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2NoLayerNormConvLayer with Wav2Vec2->SEW
class SEWNoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -187,7 +63,6 @@ def forward(self, hidden_states):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2LayerNormConvLayer with Wav2Vec2->SEW
class SEWLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -215,7 +90,6 @@ def forward(self, hidden_states):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2GroupNormConvLayer with Wav2Vec2->SEW
class SEWGroupNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -283,7 +157,6 @@ def forward(self, hidden_states):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2SamePadLayer with Wav2Vec2->SEW
class SEWSamePadLayer(nn.Module):
def __init__(self, num_conv_pos_embeddings):
super().__init__()
@@ -317,7 +190,6 @@ def forward(self, hidden_states):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeatureEncoder with Wav2Vec2->SEW
class SEWFeatureEncoder(nn.Module):
"""Construct the features from raw audio waveform"""
@@ -362,18 +234,36 @@ def forward(self, input_values):
return hidden_states
-class SEWFeatureExtractor(SEWFeatureEncoder):
- def __init__(self, config):
- super().__init__(config)
- warnings.warn(
- f"The class `{self.__class__.__name__}` has been depreciated "
- "and will be removed in Transformers v5. "
- f"Use `{self.__class__.__bases__[0].__name__}` instead.",
- FutureWarning,
- )
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
-# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->SEW
class SEWAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -386,7 +276,6 @@ def __init__(
bias: bool = True,
is_causal: bool = False,
config: Optional[SEWConfig] = None,
- layer_idx: Optional[int] = None,
):
super().__init__()
self.embed_dim = embed_dim
@@ -403,23 +292,12 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.is_decoder = is_decoder
self.is_causal = is_causal
- self.layer_idx = layer_idx
- if layer_idx is None and self.is_decoder:
- logger.warning_once(
- f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
- "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
- "when creating this class."
- )
self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- # Ignore copy
- @deprecate_kwarg("key_value_states", version="4.55")
- @deprecate_kwarg("past_key_value", version="4.55")
- @deprecate_kwarg("cache_position", version="4.55")
def forward(
self,
hidden_states: torch.Tensor,
@@ -428,253 +306,86 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attention kwargs are mixed together
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
-
- key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
+ # if key_value_states are provided this layer is used as a cross-attention layer
+ # for the decoder
+ is_cross_attention = key_value_states is not None
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- return attn_output, attn_weights_reshaped, None
-
-
-# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW
-class SEWFlashAttention2(SEWAttention):
- """
- SEW flash attention module. This module inherits from `SEWAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- # Ignore copy
- @deprecate_kwarg("key_value_states", version="4.55")
- @deprecate_kwarg("past_key_value", version="4.55")
- @deprecate_kwarg("cache_position", version="4.55")
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # SEWFlashAttention2 attention does not support output_attentions
- if output_attentions:
- raise ValueError(
- "SEWSdpaAttention2 attention does not support `output_attentions`. "
- "Use the argument `attn_implementation='eager'` when loading the model."
- )
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim)
-
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim)
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
+ # get key, value proj
+ # `past_key_value[0].shape[2] == key_value_states.shape[1]`
+ # is checking that the `sequence_length` of the `past_key_value` is the same as
+ # the provided `key_value_states` to support prefix tuning
+ if (
+ is_cross_attention
+ and past_key_value is not None
+ and past_key_value[0].shape[2] == key_value_states.shape[1]
+ ):
+ # reuse k,v, cross_attentions
+ key_states = past_key_value[0]
+ value_states = past_key_value[1]
+ elif is_cross_attention:
+ # cross_attentions
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ elif past_key_value is not None:
+ # reuse k, v, self_attention
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ key_states = torch.cat([past_key_value[0], key_states], dim=2)
+ value_states = torch.cat([past_key_value[1], value_states], dim=2)
+ else:
+ # self_attention
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+
+ if self.is_decoder:
+ # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
+ # Further calls to cross_attention layer can then reuse all cross-attention
+ # key/value_states (first "if" case)
+ # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
+ # all previous decoder key/value_states. Further calls to uni-directional self-attention
+ # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
+ # if encoder bi-directional self-attention `past_key_value` is always `None`
+ past_key_value = (key_states, value_states)
+
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- return attn_output, None, None
-
-
-class SEWSdpaAttention(SEWAttention):
- @deprecate_kwarg("key_value_states", version="4.55")
- @deprecate_kwarg("past_key_value", version="4.55")
- @deprecate_kwarg("cache_position", version="4.55")
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- cache_position: Optional[torch.Tensor] = None,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "SEWModel is using SEWSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- key_states = self.k_proj(hidden_states)
- value_states = self.v_proj(hidden_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
-
- causal_mask = None
- if attention_mask is not None: # no matter the length, we just slice it
- causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-
- # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
- # Reference: https://github.com/pytorch/pytorch/issues/112577.
- if query_states.device.type == "cuda" and causal_mask is not None:
- query_states = query_states.contiguous()
- key_states = key_states.contiguous()
- value_states = value_states.contiguous()
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False
-
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
- query_states,
- key_states,
- value_states,
- attn_mask=causal_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- attn_output = attn_output.transpose(1, 2).contiguous()
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, None
-
+ return attn_output, attn_weights, past_key_value
-SEW_ATTENTION_CLASSES = {
- "eager": SEWAttention,
- "sdpa": SEWSdpaAttention,
- "flash_attention_2": SEWFlashAttention2,
-}
-
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2FeedForward with Wav2Vec2->SEW
class SEWFeedForward(nn.Module):
def __init__(self, config):
super().__init__()
@@ -699,15 +410,15 @@ def forward(self, hidden_states):
return hidden_states
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2EncoderLayer with Wav2Vec2->SEW, WAV2VEC2->SEW
class SEWEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = SEW_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = SEWAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
@@ -746,7 +457,6 @@ def __init__(self, config):
self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.upsample = SEWUpsampling(config)
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -761,7 +471,7 @@ def forward(
if attention_mask is not None:
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- if self._use_flash_attention_2:
+ if self.config._attn_implementation == "flash_attention_2":
# make sure padded tokens output 0
hidden_states[~expand_attention_mask] = 0.0
# 2d mask is passed through the layers
@@ -854,6 +564,7 @@ class SEWPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ _supports_flex_attn = False # needs a proper look into the mask creation
def _init_weights(self, module):
"""Initialize the weights"""
@@ -915,6 +626,125 @@ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attenti
return attention_mask
+def _compute_mask_indices(
+ shape: Tuple[int, int],
+ mask_prob: float,
+ mask_length: int,
+ attention_mask: Optional[torch.LongTensor] = None,
+ min_masks: int = 0,
+) -> np.ndarray:
+ """
+ Computes random mask spans for a given shape. Used to implement [SpecAugment: A Simple Data Augmentation Method for
+ ASR](https://arxiv.org/abs/1904.08779). Note that this method is not optimized to run on TPU and should be run on
+ CPU as part of the preprocessing during training.
+
+ Args:
+ shape: The shape for which to compute masks. This should be of a tuple of size 2 where
+ the first element is the batch size and the second element is the length of the axis to span.
+ mask_prob: The percentage of the whole axis (between 0 and 1) which will be masked. The number of
+ independently generated mask spans of length `mask_length` is computed by
+ `mask_prob*shape[1]/mask_length`. Note that due to overlaps, `mask_prob` is an upper bound and the
+ actual percentage will be smaller.
+ mask_length: size of the mask
+ min_masks: minimum number of masked spans
+ attention_mask: A (right-padded) attention mask which independently shortens the feature axis of
+ each batch dimension.
+ """
+ batch_size, sequence_length = shape
+
+ if mask_length < 1:
+ raise ValueError("`mask_length` has to be bigger than 0.")
+
+ if mask_length > sequence_length:
+ raise ValueError(
+ f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length}"
+ f" and `sequence_length`: {sequence_length}`"
+ )
+
+ # epsilon is used for probabilistic rounding
+ epsilon = np.random.rand(1).item()
+
+ def compute_num_masked_span(input_length):
+ """Given input length, compute how many spans should be masked"""
+ num_masked_span = int(mask_prob * input_length / mask_length + epsilon)
+ num_masked_span = max(num_masked_span, min_masks)
+
+ # make sure num masked span <= sequence_length
+ if num_masked_span * mask_length > sequence_length:
+ num_masked_span = sequence_length // mask_length
+
+ # make sure num_masked span is also <= input_length - (mask_length - 1)
+ if input_length - (mask_length - 1) < num_masked_span:
+ num_masked_span = max(input_length - (mask_length - 1), 0)
+
+ return num_masked_span
+
+ # compute number of masked spans in batch
+ input_lengths = (
+ attention_mask.detach().sum(-1).tolist()
+ if attention_mask is not None
+ else [sequence_length for _ in range(batch_size)]
+ )
+
+ # SpecAugment mask to fill
+ spec_aug_mask = np.zeros((batch_size, sequence_length), dtype=bool)
+ spec_aug_mask_idxs = []
+
+ max_num_masked_span = compute_num_masked_span(sequence_length)
+
+ if max_num_masked_span == 0:
+ return spec_aug_mask
+
+ for input_length in input_lengths:
+ # compute num of masked spans for this input
+ num_masked_span = compute_num_masked_span(input_length)
+
+ # get random indices to mask
+ spec_aug_mask_idx = np.random.choice(
+ np.arange(input_length - (mask_length - 1)), num_masked_span, replace=False
+ )
+
+ # pick first sampled index that will serve as a dummy index to pad vector
+ # to ensure same dimension for all batches due to probabilistic rounding
+ # Picking first sample just pads those vectors twice.
+ if len(spec_aug_mask_idx) == 0:
+            # this case can only happen if `input_length` is strictly smaller than
+ # `sequence_length` in which case the last token has to be a padding
+ # token which we can use as a dummy mask id
+ dummy_mask_idx = sequence_length - 1
+ else:
+ dummy_mask_idx = spec_aug_mask_idx[0]
+
+ spec_aug_mask_idx = np.concatenate(
+ [spec_aug_mask_idx, np.ones(max_num_masked_span - num_masked_span, dtype=np.int32) * dummy_mask_idx]
+ )
+ spec_aug_mask_idxs.append(spec_aug_mask_idx)
+
+ spec_aug_mask_idxs = np.array(spec_aug_mask_idxs)
+
+ # expand masked indices to masked spans
+ spec_aug_mask_idxs = np.broadcast_to(
+ spec_aug_mask_idxs[:, :, None], (batch_size, max_num_masked_span, mask_length)
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs.reshape(batch_size, max_num_masked_span * mask_length)
+
+ # add offset to the starting indexes so that indexes now create a span
+ offsets = np.arange(mask_length)[None, None, :]
+ offsets = np.broadcast_to(offsets, (batch_size, max_num_masked_span, mask_length)).reshape(
+ batch_size, max_num_masked_span * mask_length
+ )
+ spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
+
+ # ensure that we cannot have indices larger than sequence_length
+ if spec_aug_mask_idxs.max() > sequence_length - 1:
+ spec_aug_mask_idxs[spec_aug_mask_idxs > sequence_length - 1] = sequence_length - 1
+
+ # scatter indices to mask
+ np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
+
+ return spec_aug_mask
+
+
@auto_docstring
class SEWModel(SEWPreTrainedModel):
def __init__(self, config: SEWConfig):
@@ -1038,12 +868,14 @@ def forward(
)
+_HIDDEN_STATES_START_POSITION = 1
+
+
@auto_docstring(
custom_intro="""
SEW Model with a `language modeling` head on top for Connectionist Temporal Classification (CTC).
"""
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForCTC with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
class SEWForCTC(SEWPreTrainedModel):
def __init__(self, config, target_lang: Optional[str] = None):
r"""
@@ -1196,11 +1028,10 @@ def forward(
@auto_docstring(
custom_intro="""
- SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like SUPERB
- Keyword Spotting.
+ SEW Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+ SUPERB Keyword Spotting.
"""
)
-# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification with Wav2Vec2->SEW, wav2vec2->sew, WAV2VEC2->SEW
class SEWForSequenceClassification(SEWPreTrainedModel):
def __init__(self, config):
super().__init__(config)
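
The net effect of the hunks above is that the per-backend subclasses (`SEWFlashAttention2`, `SEWSdpaAttention`) are gone and a single `SEWAttention` dispatches through an attention-function registry. The sketch below is a simplified, self-contained illustration of that pattern, not the real SEW code: the toy module and shapes are illustrative, the registry stands in for `ALL_ATTENTION_FUNCTIONS`, and `torch>=2.1` is assumed for the `scale` argument of SDPA.

```python
import torch
from torch import nn


def eager_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, head_mask=None, **kwargs):
    # query/key/value: (batch, heads, seq_len, head_dim)
    if scaling is None:
        scaling = query.size(-1) ** -0.5
    attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
    if attention_mask is not None:
        attn_weights = attn_weights + attention_mask
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)
    if head_mask is not None:
        attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
    attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
    attn_output = torch.matmul(attn_weights, value).transpose(1, 2).contiguous()
    return attn_output, attn_weights


def sdpa_attention_forward(module, query, key, value, attention_mask, scaling=None, dropout=0.0, **kwargs):
    # Thin wrapper over torch SDPA; mirrors backends that cannot return attention weights.
    attn_output = nn.functional.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=dropout, scale=scaling
    )
    return attn_output.transpose(1, 2).contiguous(), None


# Stand-in for ALL_ATTENTION_FUNCTIONS; "eager" stays the fallback.
ATTENTION_FUNCTIONS = {"eager": eager_attention_forward, "sdpa": sdpa_attention_forward}


class ToySelfAttention(nn.Module):
    def __init__(self, embed_dim=64, num_heads=4, attn_implementation="sdpa"):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, embed_dim // num_heads
        self.scaling = self.head_dim**-0.5
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_implementation = attn_implementation

    def forward(self, hidden_states):
        bsz, seq_len, _ = hidden_states.shape
        query, key, value = self.qkv(hidden_states).chunk(3, dim=-1)
        shape = (bsz, seq_len, self.num_heads, self.head_dim)
        query, key, value = (t.reshape(*shape).transpose(1, 2) for t in (query, key, value))
        attention_interface = ATTENTION_FUNCTIONS.get(self.attn_implementation, eager_attention_forward)
        attn_output, _ = attention_interface(self, query, key, value, None, scaling=self.scaling, dropout=0.0)
        return self.out_proj(attn_output.reshape(bsz, seq_len, -1))


print(ToySelfAttention()(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```
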
diff --git a/src/transformers/models/sew/modular_sew.py b/src/transformers/models/sew/modular_sew.py
new file mode 100644
index 000000000000..b3aa3e01b6cd
--- /dev/null
+++ b/src/transformers/models/sew/modular_sew.py
@@ -0,0 +1,469 @@
+# coding=utf-8
+# Copyright 2021 ASAPP Inc. and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch SEW model."""
+
+import math
+import warnings
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from ...activations import ACT2FN
+from ...integrations.deepspeed import is_deepspeed_zero3_enabled
+from ...integrations.fsdp import is_fsdp_managed_module
+from ...modeling_outputs import BaseModelOutput
+from ...modeling_utils import PreTrainedModel
+from ...utils import auto_docstring
+from ..wav2vec2.modeling_wav2vec2 import (
+ Wav2Vec2Attention,
+ Wav2Vec2EncoderLayer,
+ Wav2Vec2FeatureEncoder,
+ Wav2Vec2FeedForward,
+ Wav2Vec2ForCTC,
+ Wav2Vec2ForSequenceClassification,
+ Wav2Vec2GroupNormConvLayer,
+ Wav2Vec2LayerNormConvLayer,
+ Wav2Vec2NoLayerNormConvLayer,
+ Wav2Vec2SamePadLayer,
+ _compute_mask_indices,
+)
+from .configuration_sew import SEWConfig
+
+
+_HIDDEN_STATES_START_POSITION = 1
+
+
+class SEWNoLayerNormConvLayer(Wav2Vec2NoLayerNormConvLayer):
+ pass
+
+
+class SEWLayerNormConvLayer(Wav2Vec2LayerNormConvLayer):
+ pass
+
+
+class SEWGroupNormConvLayer(Wav2Vec2GroupNormConvLayer):
+ pass
+
+
+class SEWPositionalConvEmbedding(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.conv = nn.Conv1d(
+ config.hidden_size,
+ config.hidden_size,
+ kernel_size=config.num_conv_pos_embeddings,
+ padding=config.num_conv_pos_embeddings // 2,
+ groups=config.num_conv_pos_embedding_groups,
+ stride=config.squeeze_factor,
+ )
+
+ weight_norm = nn.utils.weight_norm
+ if hasattr(nn.utils.parametrizations, "weight_norm"):
+ weight_norm = nn.utils.parametrizations.weight_norm
+
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ with deepspeed.zero.GatheredParameters(self.conv.weight, modifier_rank=0):
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+ if hasattr(self.conv, "parametrizations"):
+ weight_g = self.conv.parametrizations.weight.original0
+ weight_v = self.conv.parametrizations.weight.original1
+ else:
+ weight_g = self.conv.weight_g
+ weight_v = self.conv.weight_v
+ deepspeed.zero.register_external_parameter(self, weight_v)
+ deepspeed.zero.register_external_parameter(self, weight_g)
+ else:
+ self.conv = weight_norm(self.conv, name="weight", dim=2)
+
+ self.padding = SEWSamePadLayer(config.num_conv_pos_embeddings)
+ self.activation = ACT2FN[config.feat_extract_activation]
+
+ def forward(self, hidden_states):
+ hidden_states = self.conv(hidden_states)
+ hidden_states = self.padding(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ return hidden_states
+
+
+class SEWSamePadLayer(Wav2Vec2SamePadLayer):
+ pass
+
+
+class SEWUpsampling(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.projection = nn.Linear(config.hidden_size, config.hidden_size * config.squeeze_factor)
+ self.activation = ACT2FN[config.feat_extract_activation]
+ self.squeeze_factor = config.squeeze_factor
+
+ def forward(self, hidden_states):
+ hidden_states = self.projection(hidden_states)
+ hidden_states = self.activation(hidden_states)
+
+ if self.squeeze_factor > 1:
+ # transform embedding channels to sequence length
+ bsz, src_len, src_embed_dim = hidden_states.size()
+ tgt_len = src_len * self.squeeze_factor
+ tgt_embed_dim = src_embed_dim // self.squeeze_factor
+ hidden_states = hidden_states.reshape(bsz, src_len, self.squeeze_factor, tgt_embed_dim)
+ hidden_states = hidden_states.reshape(bsz, tgt_len, tgt_embed_dim)
+
+ return hidden_states
+
+
+class SEWFeatureEncoder(Wav2Vec2FeatureEncoder):
+ pass
+
+
+class SEWFeatureExtractor(SEWFeatureEncoder):
+ def __init__(self, config):
+ super().__init__(config)
+ warnings.warn(
+            f"The class `{self.__class__.__name__}` has been deprecated "
+ "and will be removed in Transformers v5. "
+ f"Use `{self.__class__.__bases__[0].__name__}` instead.",
+ FutureWarning,
+ )
+
+
+class SEWAttention(Wav2Vec2Attention):
+ pass
+
+
+class SEWFeedForward(Wav2Vec2FeedForward):
+ pass
+
+
+class SEWEncoderLayer(Wav2Vec2EncoderLayer):
+ pass
+
+
+class SEWEncoder(nn.Module):
+ def __init__(self, config):
+ super().__init__()
+ self.config = config
+ self.pos_conv_embed = SEWPositionalConvEmbedding(config)
+ self.pool = nn.AvgPool1d(config.squeeze_factor, config.squeeze_factor)
+ self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ self.dropout = nn.Dropout(config.hidden_dropout)
+ self.layers = nn.ModuleList([SEWEncoderLayer(config) for _ in range(config.num_hidden_layers)])
+ self.upsample = SEWUpsampling(config)
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ attention_mask=None,
+ output_attentions=False,
+ output_hidden_states=False,
+ return_dict=True,
+ ):
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attentions = () if output_attentions else None
+
+ if attention_mask is not None:
+ expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
+ if self.config._attn_implementation == "flash_attention_2":
+ # make sure padded tokens output 0
+ hidden_states[~expand_attention_mask] = 0.0
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ else:
+ # make sure padded tokens output 0
+ hidden_states[~expand_attention_mask] = 0.0
+ input_lengths = (attention_mask.long()).sum(-1)
+ # apply pooling formula to get real output_lengths
+ output_lengths = input_lengths // self.config.squeeze_factor
+ max_encoder_length = hidden_states.shape[1] // self.config.squeeze_factor
+ attention_ids = (
+ torch.arange(0, max_encoder_length, device=output_lengths.device)
+ .view(1, -1)
+ .expand(output_lengths.shape[0], -1)
+ )
+ attention_mask = (attention_ids < output_lengths.view(-1, 1)).long()
+
+ # extend attention_mask
+ attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
+ attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
+ attention_mask = attention_mask.expand(
+ attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
+ )
+
+ n_input_timesteps = hidden_states.shape[1]
+
+ hidden_states = hidden_states.transpose(1, 2)
+ position_embeddings = self.pos_conv_embed(hidden_states)
+ pooled_hidden_states = self.pool(hidden_states)
+ min_length = min(position_embeddings.size(-1), pooled_hidden_states.size(-1))
+ hidden_states = pooled_hidden_states[..., :min_length] + position_embeddings[..., :min_length]
+ hidden_states = hidden_states.transpose(1, 2)
+
+ hidden_states = self.layer_norm(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+
+ synced_gpus = is_deepspeed_zero3_enabled() or is_fsdp_managed_module(self)
+
+ for layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
+ dropout_probability = torch.rand([])
+
+ skip_the_layer = True if self.training and (dropout_probability < self.config.layerdrop) else False
+ if not skip_the_layer or synced_gpus:
+ # under fsdp or deepspeed zero3 all gpus must run in sync
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ layer.__call__,
+ hidden_states,
+ attention_mask,
+ output_attentions,
+ )
+ else:
+ layer_outputs = layer(
+ hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
+ )
+ hidden_states = layer_outputs[0]
+
+ if skip_the_layer:
+ layer_outputs = (None, None)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (layer_outputs[1],)
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ hidden_states = self.upsample(hidden_states)
+ if hidden_states.shape[1] < n_input_timesteps:
+ hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, n_input_timesteps - hidden_states.shape[1]))
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+@auto_docstring
+class SEWPreTrainedModel(PreTrainedModel):
+ config_class = SEWConfig
+ base_model_prefix = "sew"
+ main_input_name = "input_values"
+ supports_gradient_checkpointing = True
+ _supports_flash_attn_2 = True
+ _supports_sdpa = True
+ _supports_flex_attn = False # needs a proper look into the mask creation
+
+ def _init_weights(self, module):
+ """Initialize the weights"""
+ if isinstance(module, SEWPositionalConvEmbedding):
+ nn.init.normal_(
+ module.conv.weight,
+ mean=0,
+ std=2 * math.sqrt(1 / (module.conv.kernel_size[0] * module.conv.in_channels)),
+ )
+ nn.init.constant_(module.conv.bias, 0)
+ elif isinstance(module, nn.Linear):
+ # Slightly different from the TF version which uses truncated_normal for initialization
+ # cf https://github.com/pytorch/pytorch/pull/5617
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ elif isinstance(module, nn.Conv1d):
+ if is_deepspeed_zero3_enabled():
+ import deepspeed
+
+ if hasattr(module, "weight_v") and hasattr(module, "weight_g"):
+ with deepspeed.zero.GatheredParameters([module.weight_v, module.weight_g], modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ with deepspeed.zero.GatheredParameters(module.weight, modifier_rank=0):
+ nn.init.kaiming_normal_(module.weight.data)
+ else:
+ nn.init.kaiming_normal_(module.weight.data)
+
+ if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
+ module.bias.data.zero_()
+
+ def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
+ """
+ Computes the output length of the convolutional layers
+ """
+
+ def _conv_out_length(input_length, kernel_size, stride):
+ # 1D convolutional layer output length formula taken
+ # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
+ return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1
+
+ for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
+ input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
+
+ return input_lengths
+
+ def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+ output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+ batch_size = attention_mask.shape[0]
+
+ attention_mask = torch.zeros(
+ (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+ )
+ # these two operations makes sure that all values before the output lengths idxs are attended to
+ attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+ attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+ return attention_mask
+
+
+@auto_docstring
+class SEWModel(SEWPreTrainedModel):
+ def __init__(self, config: SEWConfig):
+ super().__init__(config)
+ self.config = config
+ self.feature_extractor = SEWFeatureEncoder(config)
+ self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
+
+ self.project_features = config.conv_dim[-1] != config.hidden_size
+ if self.project_features:
+ self.feature_projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
+ self.feature_dropout = nn.Dropout(config.feat_proj_dropout)
+
+ if config.mask_time_prob > 0.0 or config.mask_feature_prob > 0.0:
+ self.masked_spec_embed = nn.Parameter(torch.Tensor(config.hidden_size).uniform_())
+
+ self.encoder = SEWEncoder(config)
+
+ # Initialize weights and apply final processing
+ self.post_init()
+
+ # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
+ def _mask_hidden_states(
+ self,
+ hidden_states: torch.FloatTensor,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.LongTensor] = None,
+ ):
+ """
+ Masks extracted features along time axis and/or along feature axis according to
+ [SpecAugment](https://arxiv.org/abs/1904.08779).
+ """
+
+ # `config.apply_spec_augment` can set masking to False
+ if not getattr(self.config, "apply_spec_augment", True):
+ return hidden_states
+
+ # generate indices & apply SpecAugment along time axis
+ batch_size, sequence_length, hidden_size = hidden_states.size()
+
+ if mask_time_indices is not None:
+ # apply SpecAugment along time axis with given mask_time_indices
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+ elif self.config.mask_time_prob > 0 and self.training:
+ mask_time_indices = _compute_mask_indices(
+ (batch_size, sequence_length),
+ mask_prob=self.config.mask_time_prob,
+ mask_length=self.config.mask_time_length,
+ attention_mask=attention_mask,
+ min_masks=self.config.mask_time_min_masks,
+ )
+ mask_time_indices = torch.tensor(mask_time_indices, device=hidden_states.device, dtype=torch.bool)
+ hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
+
+ if self.config.mask_feature_prob > 0 and self.training:
+ # generate indices & apply SpecAugment along feature axis
+ mask_feature_indices = _compute_mask_indices(
+ (batch_size, hidden_size),
+ mask_prob=self.config.mask_feature_prob,
+ mask_length=self.config.mask_feature_length,
+ min_masks=self.config.mask_feature_min_masks,
+ )
+ mask_feature_indices = torch.tensor(mask_feature_indices, device=hidden_states.device, dtype=torch.bool)
+ mask_feature_indices = mask_feature_indices[:, None].expand(-1, sequence_length, -1)
+ hidden_states[mask_feature_indices] = 0
+
+ return hidden_states
+
+ @auto_docstring
+ def forward(
+ self,
+ input_values: Optional[torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ mask_time_indices: Optional[torch.FloatTensor] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, BaseModelOutput]:
+ r"""
+ mask_time_indices (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Indices to mask extracted features for contrastive loss. When in training mode, the model learns to predict
+ masked extracted features in *config.proj_codevector_dim* space.
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ extract_features = self.feature_extractor(input_values)
+ extract_features = extract_features.transpose(1, 2)
+ extract_features = self.layer_norm(extract_features)
+
+ if self.project_features:
+ extract_features = self.feature_projection(extract_features)
+ hidden_states = self.feature_dropout(extract_features)
+
+ if attention_mask is not None:
+ # compute reduced attention_mask corresponding to feature vectors
+ attention_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+
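+        # apply SpecAugment-style time/feature masking (active during training or when `mask_time_indices` is provided)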
+ hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
+
+ encoder_outputs = self.encoder(
+ hidden_states,
+ attention_mask=attention_mask,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+
+ hidden_states = encoder_outputs[0]
+
+ if not return_dict:
+ return (hidden_states,) + encoder_outputs[1:]
+
+ return BaseModelOutput(
+ last_hidden_state=hidden_states,
+ hidden_states=encoder_outputs.hidden_states,
+ attentions=encoder_outputs.attentions,
+ )
+
+
+class SEWForCTC(Wav2Vec2ForCTC):
+ pass
+
+
+class SEWForSequenceClassification(Wav2Vec2ForSequenceClassification):
+ pass
+
+
+__all__ = ["SEWForCTC", "SEWForSequenceClassification", "SEWModel", "SEWPreTrainedModel"]
diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
index 8db0674633fb..4acee66424bb 100755
--- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py
+++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py
@@ -15,7 +15,7 @@
"""PyTorch Speech2Text model."""
import math
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import torch
from torch import nn
@@ -23,21 +23,33 @@
from ...activations import ACT2FN
from ...generation import GenerationMixin
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
Seq2SeqLMOutput,
Seq2SeqModelOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
auto_docstring,
+ is_torch_flex_attn_available,
logging,
)
from .configuration_speech_to_text import Speech2TextConfig
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
logger = logging.get_logger(__name__)
@@ -161,7 +173,38 @@ def create_position_ids_from_input_ids(
return incremental_indices.long() + padding_idx
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Speech2Text
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
+# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Attention with Wav2Vec2->Speech2Text
class Speech2TextAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -196,9 +239,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -207,6 +247,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, kwargs for encoder, decoder, and encoder-decoder attention are mixed together
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -214,10 +257,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
+
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -232,18 +281,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -255,72 +304,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-SPEECH_TO_TEXT_ATTENTION_CLASSES = {"eager": Speech2TextAttention}
+ return attn_output, attn_weights, past_key_value
# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
@@ -329,7 +333,7 @@ def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = Speech2TextAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -393,12 +397,13 @@ def forward(
# copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT
+# TODO: change copy when applying cache class
class Speech2TextDecoderLayer(nn.Module):
def __init__(self, config: Speech2TextConfig):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = Speech2TextAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -411,7 +416,7 @@ def __init__(self, config: Speech2TextConfig):
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = SPEECH_TO_TEXT_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = Speech2TextAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -423,6 +428,7 @@ def __init__(self, config: Speech2TextConfig):
self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
@@ -519,6 +525,12 @@ class Speech2TextPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "input_features"
supports_gradient_checkpointing = True
+ # TODO: tests would need a rewrite to check for correct implementation
+ # Current tests always assume certain inputs to be passed
+ _supports_flash_attn_2 = False
+ _supports_sdpa = False
+ # Compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
std = self.config.init_std
@@ -655,10 +667,10 @@ def forward(
hidden_states = inputs_embeds + embed_pos
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -713,6 +725,29 @@ def forward(
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions
)
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class Speech2TextDecoder(Speech2TextPreTrainedModel):
"""
@@ -857,16 +892,18 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
# embed positions
positions = self.embed_positions(input_ids, past_key_values_length=past_key_values_length)
@@ -963,6 +1000,82 @@ def forward(
cross_attentions=all_cross_attentions,
)
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
@auto_docstring
class Speech2TextModel(Speech2TextPreTrainedModel):
diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
index 4936ae56366d..3bc19a75b3c3 100644
--- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
+++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py
@@ -15,7 +15,7 @@
# limitations under the License.
"""PyTorch Time Series Transformer model."""
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import torch
@@ -23,7 +23,13 @@
from ...activations import ACT2FN
from ...cache_utils import Cache, EncoderDecoderCache
-from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -31,12 +37,17 @@
Seq2SeqTSModelOutput,
Seq2SeqTSPredictionOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput
-from ...utils import auto_docstring, logging
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_time_series_transformer import TimeSeriesTransformerConfig
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
+
+
logger = logging.get_logger(__name__)
@@ -265,6 +276,37 @@ def forward(self, x):
return self.value_projection(x)
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->TimeSeriesTransformer
class TimeSeriesTransformerAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -317,17 +359,25 @@ def forward(
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
cache_position: Optional[torch.Tensor] = None,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, kwargs for encoder, decoder, and encoder-decoder attention are mixed together
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
# if key_value_states are provided this layer is used as a cross-attention layer
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
+
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- query_states = query_states * self.scaling
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
if past_key_value is not None:
if isinstance(past_key_value, EncoderDecoderCache):
@@ -348,8 +398,8 @@ def forward(
else:
key_states = self.k_proj(current_states)
value_states = self.v_proj(current_states)
- key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
- value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2)
+ key_states = key_states.view(*kv_input_shape).transpose(1, 2)
+ value_states = value_states.view(*kv_input_shape).transpose(1, 2)
if past_key_value is not None:
# save all key/value_states to cache to be re-used for fast auto-regressive generation
@@ -361,66 +411,27 @@ def forward(
if is_cross_attention:
past_key_value.is_updated[self.layer_idx] = True
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = query_states.reshape(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- attention_mask = attention_mask[:, :, :, : key_states.shape[-2]]
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+
+ attn_output, attn_weights = attention_interface(
+ self,
+ query_states,
+ key_states,
+ value_states,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
+ )
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, attn_weights_reshaped, past_key_value
+ return attn_output, attn_weights, past_key_value
# Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER
@@ -429,7 +440,7 @@ def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int]
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = TimeSeriesTransformerAttention(
embed_dim=self.embed_dim,
num_heads=config.encoder_attention_heads,
dropout=config.attention_dropout,
@@ -495,19 +506,13 @@ def forward(
return outputs
-# TODO: Implement attention with SDPA for TimeSeriesTransformer.
-TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES = {
- "eager": TimeSeriesTransformerAttention,
-}
-
-
# Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER
class TimeSeriesTransformerDecoderLayer(nn.Module):
def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None):
super().__init__()
self.embed_dim = config.d_model
- self.self_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
+ self.self_attn = TimeSeriesTransformerAttention(
embed_dim=self.embed_dim,
num_heads=config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -521,7 +526,7 @@ def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int]
self.activation_dropout = config.activation_dropout
self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
- self.encoder_attn = TIME_SERIES_TRANSFORMER_ATTENTION_CLASSES[config._attn_implementation](
+ self.encoder_attn = TimeSeriesTransformerAttention(
self.embed_dim,
config.decoder_attention_heads,
dropout=config.attention_dropout,
@@ -595,6 +600,7 @@ def forward(
layer_head_mask=cross_attn_layer_head_mask,
past_key_value=past_key_value,
output_attentions=output_attentions,
+ cache_position=cache_position,
)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
hidden_states = residual + hidden_states
@@ -626,6 +632,12 @@ class TimeSeriesTransformerPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "past_values"
supports_gradient_checkpointing = True
+ # TODO: tests would need a rewrite to check for correct implementation
+ # Current tests always assume certain inputs to be passed
+ _supports_flash_attn_2 = False
+ _supports_sdpa = False
+ # Compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
std = self.config.init_std
@@ -640,6 +652,105 @@ def _init_weights(self, module):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_causal_mask
+ def _update_causal_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ past_key_values_length: int,
+ ):
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask)
+ # Other attention flavors support in-built causal (when `mask is None`)
+ # while we need to create our specific block mask regardless
+ elif attention_mask is None:
+ attention_mask = make_flex_block_causal_mask(
+ torch.ones(
+ size=(input_shape),
+ device=inputs_embeds.device,
+ )
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask, input_shape, inputs_embeds, past_key_values_length
+ )
+
+ return attention_mask
+
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoder._update_cross_attn_mask
+ def _update_cross_attn_mask(
+ self,
+ encoder_hidden_states: Union[torch.Tensor, None],
+ encoder_attention_mask: Union[torch.Tensor, None],
+ input_shape: torch.Size,
+ inputs_embeds: torch.Tensor,
+ ):
+ # expand encoder attention mask
+ if encoder_hidden_states is not None and encoder_attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ encoder_attention_mask = encoder_attention_mask if 0 in encoder_attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa(
+ encoder_attention_mask,
+ inputs_embeds.dtype,
+ tgt_len=input_shape[-1],
+ )
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(encoder_attention_mask, torch.Tensor):
+ encoder_attention_mask = make_flex_block_causal_mask(
+ encoder_attention_mask,
+ query_length=input_shape[-1],
+ is_causal=False,
+ )
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ encoder_attention_mask = _prepare_4d_attention_mask(
+ encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
+ )
+
+ return encoder_attention_mask
+
class TimeSeriesTransformerEncoder(TimeSeriesTransformerPreTrainedModel):
"""
@@ -718,10 +829,10 @@ def forward(
hidden_states = self.layernorm_embedding(hidden_states + embed_pos)
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
- # expand attention_mask
- if attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ inputs_embeds,
+ )
encoder_states = () if output_hidden_states else None
all_attentions = () if output_attentions else None
@@ -907,16 +1018,18 @@ def forward(
past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
)
- attention_mask = _prepare_4d_causal_attention_mask(
- attention_mask, input_shape, inputs_embeds, past_key_values_length
+ attention_mask = self._update_causal_mask(
+ attention_mask,
+ input_shape,
+ inputs_embeds,
+ past_key_values_length,
+ )
+ encoder_attention_mask = self._update_cross_attn_mask(
+ encoder_hidden_states,
+ encoder_attention_mask,
+ input_shape,
+ inputs_embeds,
)
-
- # expand encoder attention mask
- if encoder_hidden_states is not None and encoder_attention_mask is not None:
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- encoder_attention_mask = _prepare_4d_attention_mask(
- encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
- )
hidden_states = self.value_embedding(inputs_embeds)
embed_pos = self.embed_positions(inputs_embeds.size(), past_key_values_length=self.config.context_length)
@@ -966,6 +1079,7 @@ def forward(
None,
output_attentions,
use_cache,
+ cache_position,
)
else:
layer_outputs = decoder_layer(
@@ -1333,7 +1447,16 @@ def forward(
attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
)
- dec_input = transformer_inputs[:, self.config.context_length :, ...]
+        # Avoid empty tensors and instead create a zeros tensor which
+        # will be treated the same in torch, i.e. matmul with an empty tensor == all 0s
+ if self.config.context_length >= transformer_inputs.shape[1]:
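+            # the slice past `context_length` would be empty here (no future values), so feed a single all-zero step instead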
+ bsz, _, dim = transformer_inputs.shape
+ dec_input = torch.zeros(
+ size=(bsz, 1, dim), device=transformer_inputs.device, dtype=transformer_inputs.dtype
+ )
+ else:
+ dec_input = transformer_inputs[:, self.config.context_length :, ...]
+
decoder_outputs = self.decoder(
inputs_embeds=dec_input,
attention_mask=decoder_attention_mask,
diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py
index 071db33fdcd4..07ee6608b7ee 100755
--- a/src/transformers/models/unispeech/modeling_unispeech.py
+++ b/src/transformers/models/unispeech/modeling_unispeech.py
@@ -4,10 +4,25 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_unispeech.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import warnings
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -17,7 +32,8 @@
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -25,13 +41,14 @@
SequenceClassifierOutput,
Wav2Vec2BaseModelOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from .configuration_unispeech import UniSpeechConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -263,6 +280,36 @@ def forward(self, hidden_states):
return hidden_states, norm_hidden_states
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
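+    # transpose back to (batch, seq_len, num_heads, head_dim) so the caller can merge the head dimension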
+
+ return attn_output, attn_weights
+
+
class UniSpeechAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -297,9 +344,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -308,6 +352,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, kwargs for encoder, decoder, and encoder-decoder attention are mixed together
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -315,267 +362,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-class UniSpeechFlashAttention2(UniSpeechAttention):
- """
- UniSpeech flash attention module. This module inherits from `UniSpeechAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class UniSpeechSdpaAttention(UniSpeechAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "UniSpeechModel is using UniSpeechSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -590,18 +386,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -613,39 +409,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
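+        # non-eager kernels (e.g. sdpa, flash_attention_2) are dispatched through ALL_ATTENTION_FUNCTIONS above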
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
class UniSpeechFeedForward(nn.Module):
@@ -672,21 +456,15 @@ def forward(self, hidden_states):
return hidden_states
-UNISPEECH_ATTENTION_CLASSES = {
- "eager": UniSpeechAttention,
- "sdpa": UniSpeechSdpaAttention,
- "flash_attention_2": UniSpeechFlashAttention2,
-}
-
-
class UniSpeechEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = UniSpeechAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
@@ -723,7 +501,6 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([UniSpeechEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -740,16 +517,11 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -798,6 +570,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class UniSpeechAttnAdapterLayer(nn.Module):
def __init__(self, config):
@@ -827,11 +621,12 @@ def forward(self, hidden_states: torch.FloatTensor):
class UniSpeechEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = UNISPEECH_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = UniSpeechAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -880,7 +675,6 @@ def __init__(self, config):
[UniSpeechEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -894,19 +688,14 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- # make sure padded tokens are not attended to
+ # make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype)
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -957,6 +746,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class UniSpeechGumbelVectorQuantizer(nn.Module):
"""
@@ -1036,6 +847,8 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+    # Compile issues with flex attention
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/unispeech/modular_unispeech.py b/src/transformers/models/unispeech/modular_unispeech.py
index 9c1eae48e851..795ab8596730 100644
--- a/src/transformers/models/unispeech/modular_unispeech.py
+++ b/src/transformers/models/unispeech/modular_unispeech.py
@@ -1,3 +1,19 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch UniSpeech model."""
+
import math
import warnings
from dataclasses import dataclass
@@ -135,6 +151,8 @@ class UniSpeechPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+    # Compile issues with flex attention
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
index 8e67b41a77f3..8d9ac9c33fcc 100755
--- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
@@ -4,10 +4,25 @@
# the file from the modular. If any change should be done, please apply the change to the
# modular_unispeech_sat.py file directly. One of our CI enforces this.
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
import math
import warnings
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -17,7 +32,8 @@
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -27,13 +43,14 @@
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
-from ...modeling_utils import PreTrainedModel
-from ...utils import auto_docstring, is_peft_available, logging
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
+from ...utils import auto_docstring, is_peft_available, is_torch_flex_attn_available, logging
from .configuration_unispeech_sat import UniSpeechSatConfig
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -266,6 +283,36 @@ def forward(self, hidden_states):
return hidden_states, norm_hidden_states
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
class UniSpeechSatAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -300,9 +347,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -311,6 +355,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attn are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -318,267 +365,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
-
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-class UniSpeechSatFlashAttention2(UniSpeechSatAttention):
- """
- UniSpeechSat flash attention module. This module inherits from `UniSpeechSatAttention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class UniSpeechSatSdpaAttention(UniSpeechSatAttention):
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "UniSpeechSatModel is using UniSpeechSatSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -593,18 +389,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -616,39 +412,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
+ return attn_output, attn_weights, past_key_value
class UniSpeechSatFeedForward(nn.Module):
@@ -675,21 +459,15 @@ def forward(self, hidden_states):
return hidden_states
-UNISPEECH_SAT_ATTENTION_CLASSES = {
- "eager": UniSpeechSatAttention,
- "sdpa": UniSpeechSatSdpaAttention,
- "flash_attention_2": UniSpeechSatFlashAttention2,
-}
-
-
class UniSpeechSatEncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = UniSpeechSatAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
@@ -726,7 +504,6 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([UniSpeechSatEncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -743,16 +520,11 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -801,6 +573,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class UniSpeechSatAttnAdapterLayer(nn.Module):
def __init__(self, config):
@@ -830,11 +624,12 @@ def forward(self, hidden_states: torch.FloatTensor):
class UniSpeechSatEncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = UNISPEECH_SAT_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = UniSpeechSatAttention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -883,7 +678,6 @@ def __init__(self, config):
[UniSpeechSatEncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -897,19 +691,14 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- # make sure padded tokens are not attended to
+ # make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype)
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -960,6 +749,28 @@ def forward(
attentions=all_self_attentions,
)
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class UniSpeechSatGumbelVectorQuantizer(nn.Module):
"""
@@ -1039,6 +850,8 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+    # Compile issues with flex attention
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
index 22de13c3bdc0..9f9e7d4f3c52 100644
--- a/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
+++ b/src/transformers/models/unispeech_sat/modular_unispeech_sat.py
@@ -1,3 +1,19 @@
+# coding=utf-8
+# Copyright 2021 The Fairseq Authors and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""PyTorch UniSpeechSat model."""
+
import math
import warnings
from dataclasses import dataclass
@@ -145,6 +161,8 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+    # Compile issues with flex attention
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
index 7eb38da9acbe..fb01234e3fed 100755
--- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py
+++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py
@@ -17,7 +17,7 @@
import math
import warnings
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import Callable, Optional, Tuple, Union
import numpy as np
import torch
@@ -28,7 +28,11 @@
from ...activations import ACT2FN
from ...integrations.deepspeed import is_deepspeed_zero3_enabled
from ...integrations.fsdp import is_fsdp_managed_module
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_attn_mask_utils import (
+ _prepare_4d_attention_mask,
+ _prepare_4d_attention_mask_for_sdpa,
+)
+from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_outputs import (
BaseModelOutput,
CausalLMOutput,
@@ -38,7 +42,8 @@
Wav2Vec2BaseModelOutput,
XVectorOutput,
)
-from ...modeling_utils import PreTrainedModel
+from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from ...processing_utils import Unpack
from ...utils import (
ModelOutput,
auto_docstring,
@@ -46,6 +51,7 @@
check_torch_load_is_safe,
is_peft_available,
is_safetensors_available,
+ is_torch_flex_attn_available,
logging,
)
from .configuration_wav2vec2 import Wav2Vec2Config
@@ -58,8 +64,8 @@
from safetensors.torch import load_file as safe_load_file
-if is_flash_attn_available():
- from ...modeling_flash_attention_utils import _flash_attention_forward
+if is_torch_flex_attn_available():
+ from ...integrations.flex_attention import make_flex_block_causal_mask
logger = logging.get_logger(__name__)
@@ -465,7 +471,37 @@ def forward(self, hidden_states):
return hidden_states, norm_hidden_states
-# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Wav2Vec2
+# Copied from transformers.models.bart.modeling_bart.eager_attention_forward
+def eager_attention_forward(
+ module: nn.Module,
+ query: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ attention_mask: Optional[torch.Tensor],
+ scaling: Optional[float] = None,
+ dropout: float = 0.0,
+ head_mask: Optional[torch.Tensor] = None,
+ **kwargs,
+):
+ if scaling is None:
+ scaling = query.size(-1) ** -0.5
+
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
+ if attention_mask is not None:
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
+
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+ attn_output = torch.matmul(attn_weights, value)
+ attn_output = attn_output.transpose(1, 2).contiguous()
+
+ return attn_output, attn_weights
+
+
class Wav2Vec2Attention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
@@ -500,9 +536,6 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
- def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
-
def forward(
self,
hidden_states: torch.Tensor,
@@ -511,6 +544,9 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
+ # TODO: we need a refactor so that the different attention modules can get their specific kwargs
+        # ATM, encoder, decoder, and encoder-decoder attn are mixed together here
+ **kwargs: Unpack[FlashAttentionKwargs],
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""
@@ -518,269 +554,16 @@ def forward(
# for the decoder
is_cross_attention = key_value_states is not None
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states) * self.scaling
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0]
- value_states = past_key_value[1]
- elif is_cross_attention:
- # cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
- else:
- # self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states, value_states)
-
- proj_shape = (bsz * self.num_heads, -1, self.head_dim)
- query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape)
- key_states = key_states.reshape(*proj_shape)
- value_states = value_states.reshape(*proj_shape)
-
- src_len = key_states.size(1)
- attn_weights = torch.bmm(query_states, key_states.transpose(1, 2))
+ # determine input shapes
+ bsz, tgt_len = hidden_states.shape[:-1]
+ src_len = key_value_states.shape[1] if is_cross_attention else tgt_len
- if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len):
- raise ValueError(
- f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is"
- f" {attn_weights.size()}"
- )
-
- if attention_mask is not None:
- if attention_mask.size() != (bsz, 1, tgt_len, src_len):
- raise ValueError(
- f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}"
- )
- attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- attn_weights = nn.functional.softmax(attn_weights, dim=-1)
-
- if layer_head_mask is not None:
- if layer_head_mask.size() != (self.num_heads,):
- raise ValueError(
- f"Head mask for a single layer should be of size {(self.num_heads,)}, but is"
- f" {layer_head_mask.size()}"
- )
- attn_weights = layer_head_mask.view(1, -1, 1, 1) * attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-
- if output_attentions:
- # this operation is a bit awkward, but it's required to
- # make sure that attn_weights keeps its gradient.
- # In order to do so, attn_weights have to be reshaped
- # twice and have to be reused in the following
- attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
- attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len)
- else:
- attn_weights_reshaped = None
-
- attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
-
- attn_output = torch.bmm(attn_probs, value_states)
-
- if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz * self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim)
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
- attn_output = self.out_proj(attn_output)
-
- return attn_output, attn_weights_reshaped, past_key_value
-
-
-# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->Wav2Vec2
-class Wav2Vec2FlashAttention2(Wav2Vec2Attention):
- """
- Wav2Vec2 flash attention module. This module inherits from `Wav2Vec2Attention` as the weights of the module stays
- untouched. The only required change would be on the forward pass where it needs to correctly call the public API of
- flash attention and deal with padding tokens in case the input contains any of them.
- """
-
- def __init__(self, *args, **kwargs):
- super().__init__(*args, **kwargs)
-
- # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
- # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
- # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
- self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask()
-
- def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
- return tensor.view(bsz, seq_len, self.num_heads, self.head_dim)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
-
- bsz, q_len, _ = hidden_states.size()
+ q_input_shape = (bsz, tgt_len, -1, self.head_dim)
+ kv_input_shape = (bsz, src_len, -1, self.head_dim)
# get query proj
- query_states = self._reshape(self.q_proj(hidden_states), -1, bsz)
- # get key, value proj
- # `past_key_value[0].shape[2] == key_value_states.shape[1]`
- # is checking that the `sequence_length` of the `past_key_value` is the same as
- # the provided `key_value_states` to support prefix tuning
- if (
- is_cross_attention
- and past_key_value is not None
- and past_key_value[0].shape[2] == key_value_states.shape[1]
- ):
- # reuse k,v, cross_attentions
- key_states = past_key_value[0].transpose(1, 2)
- value_states = past_key_value[1].transpose(1, 2)
- elif is_cross_attention:
- # cross_attentions
- key_states = self._reshape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._reshape(self.v_proj(key_value_states), -1, bsz)
- elif past_key_value is not None:
- # reuse k, v, self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
- key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1)
- value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1)
- else:
- # self_attention
- key_states = self._reshape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._reshape(self.v_proj(hidden_states), -1, bsz)
-
- if self.is_decoder:
- # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
- # Further calls to cross_attention layer can then reuse all cross-attention
- # key/value_states (first "if" case)
- # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of
- # all previous decoder key/value_states. Further calls to uni-directional self-attention
- # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case)
- # if encoder bi-directional self-attention `past_key_value` is always `None`
- past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2))
-
- kv_seq_len = key_states.shape[-2]
- if past_key_value is not None:
- kv_seq_len += past_key_value[0].shape[-2]
-
- # In PEFT, usually we cast the layer norms in float32 for training stability reasons
- # therefore the input hidden states gets silently casted in float32. Hence, we need
- # cast them back in the correct dtype just to be sure everything works as expected.
- # This might slowdown training & inference so it is recommended to not cast the LayerNorms
- # in fp32. (LlamaRMSNorm handles it correctly)
-
- input_dtype = query_states.dtype
- if input_dtype == torch.float32:
- if torch.is_autocast_enabled():
- target_dtype = torch.get_autocast_gpu_dtype()
- # Handle the case where the model is quantized
- elif hasattr(self.config, "_pre_quantization_dtype"):
- target_dtype = self.config._pre_quantization_dtype
- else:
- target_dtype = self.q_proj.weight.dtype
-
- logger.warning_once(
- f"The input hidden states seems to be silently casted in float32, this might be related to"
- f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
- f" {target_dtype}."
- )
-
- query_states = query_states.to(target_dtype)
- key_states = key_states.to(target_dtype)
- value_states = value_states.to(target_dtype)
-
- attn_output = _flash_attention_forward(
- query_states,
- key_states,
- value_states,
- attention_mask,
- q_len,
- dropout=self.dropout if self.training else 0.0,
- is_causal=self.is_causal,
- use_top_left_mask=self._flash_attn_uses_top_left_mask,
- )
-
- attn_output = attn_output.reshape(bsz, q_len, -1)
- attn_output = self.out_proj(attn_output)
-
- if not output_attentions:
- attn_weights = None
-
- return attn_output, attn_weights, past_key_value
-
-
-class Wav2Vec2SdpaAttention(Wav2Vec2Attention):
- # Copied from transformers.models.hubert.modeling_hubert.HubertSdpaAttention.forward with Hubert->Wav2Vec2
- def forward(
- self,
- hidden_states: torch.Tensor,
- key_value_states: Optional[torch.Tensor] = None,
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- layer_head_mask: Optional[torch.Tensor] = None,
- output_attentions: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
- """Input shape: Batch x Time x Channel"""
- if output_attentions:
- # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once this is implemented.
- logger.warning_once(
- "Wav2Vec2Model is using Wav2Vec2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True` . Falling back to the manual attention"
- ' implementation, but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
- )
- return super().forward(
- hidden_states,
- key_value_states=key_value_states,
- past_key_value=past_key_value,
- attention_mask=attention_mask,
- output_attentions=output_attentions,
- )
-
- # if key_value_states are provided this layer is used as a cross-attention layer
- # for the decoder
- is_cross_attention = key_value_states is not None
+ query_states = self.q_proj(hidden_states).view(*q_input_shape).transpose(1, 2)
- bsz, tgt_len, _ = hidden_states.size()
-
- # get query proj
- query_states = self.q_proj(hidden_states)
# get key, value proj
# `past_key_value[0].shape[2] == key_value_states.shape[1]`
# is checking that the `sequence_length` of the `past_key_value` is the same as
@@ -795,18 +578,18 @@ def forward(
value_states = past_key_value[1]
elif is_cross_attention:
# cross_attentions
- key_states = self._shape(self.k_proj(key_value_states), -1, bsz)
- value_states = self._shape(self.v_proj(key_value_states), -1, bsz)
+ key_states = self.k_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(key_value_states).view(*kv_input_shape).transpose(1, 2)
elif past_key_value is not None:
# reuse k, v, self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
else:
# self_attention
- key_states = self._shape(self.k_proj(hidden_states), -1, bsz)
- value_states = self._shape(self.v_proj(hidden_states), -1, bsz)
+ key_states = self.k_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
+ value_states = self.v_proj(hidden_states).view(*kv_input_shape).transpose(1, 2)
if self.is_decoder:
# if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states.
@@ -818,46 +601,27 @@ def forward(
# if encoder bi-directional self-attention `past_key_value` is always `None`
past_key_value = (key_states, value_states)
- query_states = self._shape(query_states, tgt_len, bsz)
-
- # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
- # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
- # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1.
- is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False
+ attention_interface: Callable = eager_attention_forward
+ if self.config._attn_implementation != "eager":
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
- # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask,
- # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577
- attn_output = torch.nn.functional.scaled_dot_product_attention(
+ attn_output, attn_weights = attention_interface(
+ self,
query_states,
key_states,
value_states,
- attn_mask=attention_mask,
- dropout_p=self.dropout if self.training else 0.0,
- is_causal=is_causal,
+ attention_mask,
+ dropout=0.0 if not self.training else self.dropout,
+ scaling=self.scaling,
+ output_attentions=output_attentions,
+ head_mask=layer_head_mask,
+ **kwargs,
)
- if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim):
- raise ValueError(
- f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is"
- f" {attn_output.size()}"
- )
-
- attn_output = attn_output.transpose(1, 2)
-
- # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be
- # partitioned across GPUs when using tensor-parallelism.
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim)
-
+ attn_output = attn_output.reshape(bsz, tgt_len, -1).contiguous()
attn_output = self.out_proj(attn_output)
- return attn_output, None, past_key_value
-
-
-WAV2VEC2_ATTENTION_CLASSES = {
- "eager": Wav2Vec2Attention,
- "sdpa": Wav2Vec2SdpaAttention,
- "flash_attention_2": Wav2Vec2FlashAttention2,
-}
+ return attn_output, attn_weights, past_key_value
class Wav2Vec2FeedForward(nn.Module):
@@ -887,11 +651,12 @@ def forward(self, hidden_states):
class Wav2Vec2EncoderLayer(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = Wav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
@@ -922,11 +687,12 @@ def forward(self, hidden_states, attention_mask=None, output_attentions=False):
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
def __init__(self, config):
super().__init__()
- self.attention = WAV2VEC2_ATTENTION_CLASSES[config._attn_implementation](
+ self.attention = Wav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.attention_dropout,
is_decoder=False,
+ config=config,
)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -973,7 +739,6 @@ def __init__(self, config):
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -990,16 +755,11 @@ def forward(
# make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
hidden_states[~expand_attention_mask] = 0
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -1048,6 +808,29 @@ def forward(
attentions=all_self_attentions,
)
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
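The refactored encoders delegate padding-mask expansion to the `_update_full_mask` helper added above. As a rough illustration of what the default (eager) branch produces, the sketch below expands a 2D padding mask into the additive 4D shape noted in the comments; it is a stand-in for `_prepare_4d_attention_mask`, not the library helper itself, and the sizes are arbitrary.

```python
import torch

def expand_padding_mask(attention_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # [bsz, src_len] with 1 = keep, 0 = pad  ->  additive [bsz, 1, src_len, src_len] mask
    bsz, src_len = attention_mask.shape
    expanded = attention_mask[:, None, None, :].expand(bsz, 1, src_len, src_len).to(dtype)
    inverted = 1.0 - expanded
    # padded positions get a large negative value, kept positions stay at 0
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)

mask = torch.tensor([[1, 1, 1, 0]])
print(expand_padding_mask(mask, torch.float32)[0, 0])  # last column is ~ -3.4e38
```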
class Wav2Vec2EncoderStableLayerNorm(nn.Module):
def __init__(self, config):
@@ -1060,7 +843,6 @@ def __init__(self, config):
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
self.gradient_checkpointing = False
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
def forward(
self,
@@ -1074,19 +856,14 @@ def forward(
all_self_attentions = () if output_attentions else None
if attention_mask is not None:
- # make sure padded tokens are not attended to
+ # make sure padded tokens output 0
expand_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[2])
- hidden_states = hidden_states * expand_attention_mask.to(dtype=hidden_states.dtype)
- if self._use_flash_attention_2:
- # 2d mask is passed through the layers
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
- else:
- # extend attention_mask
- attention_mask = 1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)
- attention_mask = attention_mask * torch.finfo(hidden_states.dtype).min
- attention_mask = attention_mask.expand(
- attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
- )
+ hidden_states[~expand_attention_mask] = 0
+
+ attention_mask = self._update_full_mask(
+ attention_mask,
+ hidden_states,
+ )
position_embeddings = self.pos_conv_embed(hidden_states)
hidden_states = hidden_states + position_embeddings
@@ -1137,6 +914,29 @@ def forward(
attentions=all_self_attentions,
)
+ # Copied from transformers.models.bart.modeling_bart.BartPreTrainedModel._update_full_mask
+ def _update_full_mask(
+ self,
+ attention_mask: Union[torch.Tensor, None],
+ inputs_embeds: torch.Tensor,
+ ):
+ if attention_mask is not None:
+ if self.config._attn_implementation == "flash_attention_2":
+ attention_mask = attention_mask if 0 in attention_mask else None
+ elif self.config._attn_implementation == "sdpa":
+ # output_attentions=True & head_mask can not be supported when using SDPA, fall back to
+ # the manual implementation that requires a 4D causal mask in all cases.
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask_for_sdpa(attention_mask, inputs_embeds.dtype)
+ elif self.config._attn_implementation == "flex_attention":
+ if isinstance(attention_mask, torch.Tensor):
+ attention_mask = make_flex_block_causal_mask(attention_mask, is_causal=False)
+ else:
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+ attention_mask = _prepare_4d_attention_mask(attention_mask, inputs_embeds.dtype)
+
+ return attention_mask
+
class Wav2Vec2GumbelVectorQuantizer(nn.Module):
"""
@@ -1296,6 +1096,8 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = True
_supports_sdpa = True
+ # Flex attention disabled due to compile issues
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
index fc660d06ac9a..af02fa91ee1d 100644
--- a/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
+++ b/src/transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py
@@ -26,7 +26,11 @@
XVectorOutput,
)
from ...modeling_utils import PreTrainedModel
-from ...utils import ModelOutput, auto_docstring, is_peft_available
+from ...utils import (
+ ModelOutput,
+ auto_docstring,
+ is_peft_available,
+)
from .configuration_wav2vec2_conformer import Wav2Vec2ConformerConfig
diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py
index 90330c264525..11670ea7d21b 100755
--- a/src/transformers/models/wavlm/modeling_wavlm.py
+++ b/src/transformers/models/wavlm/modeling_wavlm.py
@@ -617,6 +617,7 @@ class WavLMPreTrainedModel(PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = False
_supports_sdpa = False
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/wavlm/modular_wavlm.py b/src/transformers/models/wavlm/modular_wavlm.py
index b74d9fc703e0..1ff9d5052c0a 100644
--- a/src/transformers/models/wavlm/modular_wavlm.py
+++ b/src/transformers/models/wavlm/modular_wavlm.py
@@ -527,6 +527,7 @@ class WavLMPreTrainedModel(PreTrainedModel, Wav2Vec2PreTrainedModel):
supports_gradient_checkpointing = True
_supports_flash_attn_2 = False
_supports_sdpa = False
+ _supports_flex_attn = False
def _init_weights(self, module):
"""Initialize the weights"""
diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py
index e42c7ce30605..ef847d60595f 100644
--- a/src/transformers/models/whisper/modeling_whisper.py
+++ b/src/transformers/models/whisper/modeling_whisper.py
@@ -27,7 +27,10 @@
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
-from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available
+from ...modeling_flash_attention_utils import (
+ flash_attn_supports_top_left_mask,
+ is_flash_attn_available,
+)
from ...modeling_outputs import (
BaseModelOutput,
BaseModelOutputWithPastAndCrossAttentions,
@@ -259,6 +262,9 @@ def __init__(
self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+ def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+ return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
def forward(
self,
hidden_states: torch.Tensor,
@@ -572,7 +578,8 @@ def forward(
}
-# Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
+# (BC Dep) Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->Whisper, MBART->WHISPER
+# TODO(vasqu): fix copies when enabling whisper attn interface
class WhisperEncoderLayer(nn.Module):
def __init__(self, config: WhisperConfig):
super().__init__()
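The `_shape` helper restored above folds the flat hidden dimension into per-head dimensions before attention. A quick shape check, with arbitrary sizes chosen purely for illustration:

```python
import torch

bsz, seq_len, num_heads, head_dim = 2, 5, 4, 8
hidden = torch.randn(bsz, seq_len, num_heads * head_dim)

# (bsz, seq_len, embed_dim) -> (bsz, num_heads, seq_len, head_dim), as in _shape
shaped = hidden.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2).contiguous()
print(shaped.shape)  # torch.Size([2, 4, 5, 8])
```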
diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py
index dc1e268cbe4a..e9f3f63c965a 100755
--- a/src/transformers/models/xglm/modeling_xglm.py
+++ b/src/transformers/models/xglm/modeling_xglm.py
@@ -282,7 +282,7 @@ def __init__(self, config: XGLMConfig):
self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim)
self.final_layer_norm = nn.LayerNorm(self.embed_dim)
- # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward
+ # Copied from transformers.models.musicgen.modeling_musicgen.MusicgenDecoderLayer.forward
def forward(
self,
hidden_states: torch.Tensor,
diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py
index 783e7cb3a7a7..57037ee435c7 100644
--- a/tests/generation/test_utils.py
+++ b/tests/generation/test_utils.py
@@ -1148,6 +1148,10 @@ def test_assisted_decoding_matches_greedy_search(self, assistant_type):
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
+ # force eager attention to support output attentions
+ if self.has_attentions:
+ config._attn_implementation = "eager"
+
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -1228,6 +1232,10 @@ def test_prompt_lookup_decoding_matches_greedy_search(self):
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
+ # force eager attention to support output attentions
+ if self.has_attentions:
+ config._attn_implementation = "eager"
+
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
@@ -1282,6 +1290,10 @@ def test_dola_decoding_sample(self):
# enable cache if the model is not openai-gpt, xlnet, cpm, or xlm
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
+ # force eager attention to support output attentions
+ if self.has_attentions:
+ config._attn_implementation = "eager"
+
# Encoder-decoder models are not supported
if config.is_encoder_decoder:
self.skipTest("DoLa is not supported for encoder-decoder models")
@@ -1346,6 +1358,10 @@ def test_assisted_decoding_sample(self):
# enable cache
config, inputs_dict = self.prepare_config_and_inputs_for_generate(batch_size=1)
+ # force eager attention to support output attentions
+ if self.has_attentions:
+ config._attn_implementation = "eager"
+
# NOTE: assisted generation only works with cache on at the moment.
if not hasattr(config.get_text_config(), "use_cache"):
self.skipTest(reason=f"{model_class.__name__} doesn't support caching")
diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py
index 34fbcd80f9bf..ded8d5f0a8e8 100644
--- a/tests/models/bart/test_modeling_bart.py
+++ b/tests/models/bart/test_modeling_bart.py
@@ -1521,3 +1521,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py
index 83a5c73ff7b7..bec16cf5dc13 100644
--- a/tests/models/blenderbot/test_modeling_blenderbot.py
+++ b/tests/models/blenderbot/test_modeling_blenderbot.py
@@ -554,3 +554,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
index 5e9163376a2c..8d75649d8cc1 100644
--- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
+++ b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py
@@ -563,3 +563,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/data2vec/test_modeling_data2vec_audio.py b/tests/models/data2vec/test_modeling_data2vec_audio.py
index 96e970beb6b9..5a8f410a70d7 100644
--- a/tests/models/data2vec/test_modeling_data2vec_audio.py
+++ b/tests/models/data2vec/test_modeling_data2vec_audio.py
@@ -420,6 +420,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
diff --git a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
index 6df93374137b..9aceea0359b0 100644
--- a/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
+++ b/tests/models/encoder_decoder/test_modeling_encoder_decoder.py
@@ -412,6 +412,10 @@ def check_encoder_decoder_model_output_attentions(
labels,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -445,6 +449,10 @@ def check_encoder_decoder_model_output_attentions_from_config(
# config file. Contrarily to most models, changing the model's config won't work -- the defaults are loaded
# from the inner models' configurations.
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
diff --git a/tests/models/hubert/test_modeling_hubert.py b/tests/models/hubert/test_modeling_hubert.py
index c4e65cfa5c6b..de26f4c7a4e3 100644
--- a/tests/models/hubert/test_modeling_hubert.py
+++ b/tests/models/hubert/test_modeling_hubert.py
@@ -370,6 +370,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -632,6 +635,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py
index c264de826c20..ba6e8f9c25eb 100644
--- a/tests/models/m2m_100/test_modeling_m2m_100.py
+++ b/tests/models/m2m_100/test_modeling_m2m_100.py
@@ -357,7 +357,8 @@ def test_inference_no_head(self):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
- [[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]], device=torch_device
+ [[[-0.7780, -0.1676, 0.1038], [-6.7556, -1.3992, 0.0567], [-7.5383, -0.5920, -0.2779]]],
+ device=torch_device,
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -374,7 +375,8 @@ def test_inference_head(self):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
- [[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]], device=torch_device
+ [[[-1.0448, -1.0411, 3.7992], [-3.2191, -3.2386, -1.3451], [-3.6210, -3.5993, 0.4925]]],
+ device=torch_device,
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -426,7 +428,7 @@ def test_flash_attn_2_seq_to_seq_generation(self):
Overwriting the common test as the test is flaky on tiny models
"""
model = M2M100ForConditionalGeneration.from_pretrained(
- "facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2"
+ "facebook/m2m100_418M", attn_implementation="flash_attention_2", torch_dtype=torch.float16
).to(torch_device)
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en")
diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py
index a53627852f5d..ed42b1b29f00 100644
--- a/tests/models/marian/test_modeling_marian.py
+++ b/tests/models/marian/test_modeling_marian.py
@@ -850,3 +850,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py
index 3b328cfe9e01..4ef22c3c30e0 100644
--- a/tests/models/mbart/test_modeling_mbart.py
+++ b/tests/models/mbart/test_modeling_mbart.py
@@ -735,3 +735,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="Decoder cannot retain gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot retain gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/musicgen/test_modeling_musicgen.py b/tests/models/musicgen/test_modeling_musicgen.py
index b14df8de0e4c..1a27192506f3 100644
--- a/tests/models/musicgen/test_modeling_musicgen.py
+++ b/tests/models/musicgen/test_modeling_musicgen.py
@@ -728,6 +728,9 @@ def check_musicgen_model_output_attentions_from_config(
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
for model_class in self.all_model_classes:
self.check_musicgen_model_output_attentions(model_class, config, **inputs_dict)
self.check_musicgen_model_output_attentions_from_config(model_class, config, **inputs_dict)
@@ -805,6 +808,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.text_encoder.output_attentions = True
config.decoder.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -1036,30 +1042,7 @@ def test_flash_attn_2_inference_equivalence(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- ).to(torch_device)
-
- for _, module in model.named_modules():
- if "FlashAttention" in module.__class__.__name__:
- return
-
- self.assertTrue(False, "FlashAttention2 modules not found in model")
+ self.skipTest(reason="Musicgen doesn't use the MusicgenFlashAttention2 class method.")
@require_torch_sdpa
@require_torch_gpu
@@ -1234,18 +1217,6 @@ def test_sdpa_can_dispatch_composite_models(self):
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.config._attn_implementation == "eager")
- for name, submodule in model_eager.named_modules():
- if "SdpaAttention" in submodule.__class__.__name__:
- raise ValueError("The eager model should not have SDPA attention layers")
-
- has_sdpa = False
- for name, submodule in model_sdpa.named_modules():
- if "SdpaAttention" in submodule.__class__.__name__:
- has_sdpa = True
- break
- if not has_sdpa and model_sdpa.config.model_type != "falcon":
- raise ValueError("The SDPA model should have SDPA attention layers")
-
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
@@ -1276,6 +1247,10 @@ def test_requires_grad_with_frozen_encoders(self):
def test_generation_tester_mixin_inheritance(self):
pass
+ @unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5."))
+ def test_sdpa_can_compile_dynamic(self):
+ pass
+
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
"""Produces a series of 'bip bip' sounds at a given frequency."""
diff --git a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
index 72cbb990c9ae..abf2edd1ce3e 100644
--- a/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
+++ b/tests/models/musicgen_melody/test_modeling_musicgen_melody.py
@@ -731,6 +731,9 @@ def check_musicgen_melody_model_output_attentions_from_config(
def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
for model_class in self.all_model_classes:
self.check_musicgen_melody_model_output_attentions(model_class, config, **inputs_dict)
self.check_musicgen_melody_model_output_attentions_from_config(model_class, config, **inputs_dict)
@@ -807,6 +810,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.text_encoder.output_attentions = True
config.decoder.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -1036,30 +1042,7 @@ def test_flash_attn_2_inference_equivalence(self):
@mark.flash_attn_test
@slow
def test_flash_attn_2_conversion(self):
- if not self.has_attentions:
- self.skipTest(reason="Model architecture does not support attentions")
-
- config, _ = self.model_tester.prepare_config_and_inputs_for_common()
-
- for model_class in self.all_model_classes:
- if not model_class._supports_flash_attn_2:
- self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")
-
- model = model_class(config)
-
- with tempfile.TemporaryDirectory() as tmpdirname:
- model.save_pretrained(tmpdirname)
- model = model_class.from_pretrained(
- tmpdirname,
- torch_dtype=torch.float16,
- attn_implementation={"decoder": "flash_attention_2", "audio_encoder": None, "text_encoder": None},
- ).to(torch_device)
-
- for _, module in model.named_modules():
- if "FlashAttention" in module.__class__.__name__:
- return
-
- self.assertTrue(False, "FlashAttention2 modules not found in model")
+ self.skipTest(reason="MusicgenMelody doesn't use the MusicgenMelodyFlashAttention2 class method.")
@require_torch_sdpa
@require_torch_gpu
@@ -1234,18 +1217,6 @@ def test_sdpa_can_dispatch_composite_models(self):
self.assertTrue(model_eager.decoder.config._attn_implementation == "eager")
self.assertTrue(model_eager.config._attn_implementation == "eager")
- for name, submodule in model_eager.named_modules():
- if "SdpaAttention" in submodule.__class__.__name__:
- raise ValueError("The eager model should not have SDPA attention layers")
-
- has_sdpa = False
- for name, submodule in model_sdpa.named_modules():
- if "SdpaAttention" in submodule.__class__.__name__:
- has_sdpa = True
- break
- if not has_sdpa and model_sdpa.config.model_type != "falcon":
- raise ValueError("The SDPA model should have SDPA attention layers")
-
def test_requires_grad_with_frozen_encoders(self):
config = self.model_tester.get_config()
for model_class in self.all_model_classes:
@@ -1276,6 +1247,10 @@ def test_requires_grad_with_frozen_encoders(self):
def test_generation_tester_mixin_inheritance(self):
pass
+ @unittest.skip(reason=("MusicGen has a set of composite models which might not have SDPA themselves, e.g. T5."))
+ def test_sdpa_can_compile_dynamic(self):
+ pass
+
# Copied from tests.models.musicgen.test_modeling_musicgen.get_bip_bip
def get_bip_bip(bip_duration=0.125, duration=0.5, sample_rate=32000):
diff --git a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py
index ad00eab111f6..9b48a2e30740 100644
--- a/tests/models/patchtsmixer/test_modeling_patchtsmixer.py
+++ b/tests/models/patchtsmixer/test_modeling_patchtsmixer.py
@@ -480,7 +480,7 @@ def test_pretrain_head(self):
)
self.assertEqual(output.shape, expected_shape)
- expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device) # fmt: skip
+ expected_slice = torch.tensor([[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],],device=torch_device) # fmt: skip
torch.testing.assert_close(output[0, :7, :1, :1], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
def test_forecasting_head(self):
diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py
index 2b7a127d4828..af119c41d335 100644
--- a/tests/models/pegasus/test_modeling_pegasus.py
+++ b/tests/models/pegasus/test_modeling_pegasus.py
@@ -597,3 +597,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py
index 2ffdc4636756..a6bf913e4c2e 100644
--- a/tests/models/pegasus_x/test_modeling_pegasus_x.py
+++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py
@@ -590,7 +590,7 @@ def test_inference_no_head(self):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
- [[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]], device=torch_device
+ [[[0.0702, -0.1552, 0.1192], [0.0836, -0.1848, 0.1304], [0.0673, -0.1686, 0.1045]]], device=torch_device
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -608,7 +608,8 @@ def test_inference_head(self):
self.assertEqual(output.shape, expected_shape)
# change to expected output here
expected_slice = torch.tensor(
- [[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]], device=torch_device
+ [[[0.0, 9.5705185, 1.5897303], [0.0, 9.833374, 1.5828674], [0.0, 10.429961, 1.5643371]]],
+ device=torch_device,
)
torch.testing.assert_close(output[:, :3, :3], expected_slice, rtol=TOLERANCE, atol=TOLERANCE)
@@ -635,8 +636,7 @@ def test_seq_to_seq_generation(self):
batch_input,
max_length=512,
padding="max_length",
- truncation_strategy="only_first",
- truncation=True,
+ truncation="only_first",
return_tensors="pt",
)
@@ -872,3 +872,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
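The `test_seq_to_seq_generation` fix above also reflects the current tokenizer API: the truncation strategy is passed directly as the value of `truncation`, rather than through a separate `truncation_strategy` flag. A minimal usage sketch (the checkpoint name is illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/pegasus-x-base")
batch = tokenizer(
    ["a very long source document ..."],
    max_length=512,
    padding="max_length",
    truncation="only_first",  # replaces truncation_strategy="only_first", truncation=True
    return_tensors="pt",
)
print(batch["input_ids"].shape)  # torch.Size([1, 512])
```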
diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py
index 569a63d435e5..179750582832 100644
--- a/tests/models/plbart/test_modeling_plbart.py
+++ b/tests/models/plbart/test_modeling_plbart.py
@@ -677,3 +677,7 @@ def test_decoder_model_attn_mask_past(self):
@unittest.skip(reason="Decoder cannot keep gradients")
def test_retain_grad_hidden_states_attentions(self):
return
+
+ @unittest.skip(reason="Decoder cannot keep gradients")
+ def test_flex_attention_with_grads(self):
+ return
diff --git a/tests/models/sew/test_modeling_sew.py b/tests/models/sew/test_modeling_sew.py
index 9a7430c64112..2cab21cf5c92 100644
--- a/tests/models/sew/test_modeling_sew.py
+++ b/tests/models/sew/test_modeling_sew.py
@@ -342,6 +342,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
diff --git a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
index 7ac8d5631d20..28cdaf34473e 100644
--- a/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
+++ b/tests/models/speech_encoder_decoder/test_modeling_speech_encoder_decoder.py
@@ -300,6 +300,10 @@ def check_encoder_decoder_model_output_attentions(
input_features=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
diff --git a/tests/models/unispeech/test_modeling_unispeech.py b/tests/models/unispeech/test_modeling_unispeech.py
index 0f8e2448320b..ebc537a47886 100644
--- a/tests/models/unispeech/test_modeling_unispeech.py
+++ b/tests/models/unispeech/test_modeling_unispeech.py
@@ -383,6 +383,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
diff --git a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py
index dd8deafc227f..ec438dea96b4 100644
--- a/tests/models/unispeech_sat/test_modeling_unispeech_sat.py
+++ b/tests/models/unispeech_sat/test_modeling_unispeech_sat.py
@@ -423,6 +423,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -632,6 +635,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
diff --git a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
index bfd9aad2332a..ffd08297f147 100644
--- a/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
+++ b/tests/models/vision_encoder_decoder/test_modeling_vision_encoder_decoder.py
@@ -246,6 +246,10 @@ def check_encoder_decoder_model_output_attentions(
pixel_values=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -480,6 +484,10 @@ def check_encoder_decoder_model_output_attentions(
pixel_values=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -670,6 +678,10 @@ def check_encoder_decoder_model_output_attentions(
pixel_values=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -807,6 +819,10 @@ def check_encoder_decoder_model_output_attentions(
labels=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -929,6 +945,10 @@ def check_encoder_decoder_model_output_attentions(
labels=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
@@ -1047,6 +1067,10 @@ def check_encoder_decoder_model_output_attentions(
labels=None,
**kwargs,
):
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+ decoder_config._attn_implementation = "eager"
+
# make the decoder inputs a different shape from the encoder inputs to harden the test
decoder_input_ids = decoder_input_ids[:, :-1]
decoder_attention_mask = decoder_attention_mask[:, :-1]
diff --git a/tests/models/wav2vec2/test_modeling_wav2vec2.py b/tests/models/wav2vec2/test_modeling_wav2vec2.py
index c47d4855ccd9..9597d2e6ef25 100644
--- a/tests/models/wav2vec2/test_modeling_wav2vec2.py
+++ b/tests/models/wav2vec2/test_modeling_wav2vec2.py
@@ -570,6 +570,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -917,6 +920,9 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index 2f16a86d80b5..a85c9e7e6256 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -958,6 +958,8 @@ def test_attention_outputs(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True
+ # force eager attention to support output attentions
+ config._attn_implementation = "eager"
seq_len = getattr(self.model_tester, "seq_length", None)
decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
@@ -1106,7 +1108,11 @@ def _create_and_check_torchscript(self, config, inputs_dict):
configs_no_init.torchscript = True
for model_class in self.all_model_classes:
for attn_implementation in ["eager", "sdpa"]:
- if attn_implementation == "sdpa" and (not model_class._supports_sdpa or not is_torch_sdpa_available()):
+ if (
+ attn_implementation == "sdpa"
+ and (not model_class._supports_sdpa or not is_torch_sdpa_available())
+ ) or config.output_attentions:
continue
configs_no_init._attn_implementation = attn_implementation
@@ -1708,6 +1714,10 @@ def test_retain_grad_hidden_states_attentions(self):
config.output_hidden_states = True
config.output_attentions = self.has_attentions
+ # force eager attention to support output attentions
+ if self.has_attentions:
+ config._attn_implementation = "eager"
+
# no need to test all models as different heads yield the same functionality
model_class = self.all_model_classes[0]
model = model_class(config)
@@ -4555,13 +4565,26 @@ def test_flex_attention_with_grads(self):
# TODO: raushan, fix for composite models after making VLMs support new attn API
if not model_class._supports_flex_attn or self._is_composite:
self.skipTest(reason="This model does not support flex attention")
+
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config._attn_implementation = "flex_attention"
- model = model_class(config).to(device=torch_device, dtype=torch.float16)
+ # Flex attention cannot use dropout
+ if hasattr(config, "attention_dropout"):
+ config.attention_dropout = 0
+ if hasattr(config, "attention_probs_dropout_prob"):
+ config.attention_probs_dropout_prob = 0
+
+ model = model_class(config).to(device=torch_device)
self.assertTrue(model.config._attn_implementation == "flex_attention")
+ # Elaborate workaround for encoder-decoder models as some do not specify their main input
+ dummy_inputs = {model.main_input_name: inputs_dict[model.main_input_name].to(torch_device)}
+ if config.is_encoder_decoder:
+ dummy_inputs["decoder_input_ids"] = inputs_dict["decoder_input_ids"]
+ dummy_inputs["decoder_attention_mask"] = inputs_dict["decoder_attention_mask"]
+
# If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
- _ = model(inputs_dict["input_ids"].to(torch_device))
+ _ = model(**dummy_inputs)
def test_generation_tester_mixin_inheritance(self):
"""