diff --git a/src/transformers/generation/candidate_generator.py b/src/transformers/generation/candidate_generator.py index 3425a0234b42..bb9222030553 100644 --- a/src/transformers/generation/candidate_generator.py +++ b/src/transformers/generation/candidate_generator.py @@ -28,7 +28,7 @@ if is_sklearn_available(): from sklearn.metrics import roc_curve -from ..cache_utils import DynamicCache +from ..cache_utils import Cache from ..pytorch_utils import isin_mps_friendly from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor @@ -1183,7 +1183,9 @@ def get_candidates(self, input_ids: torch.LongTensor) -> Tuple[torch.LongTensor, def _crop_past_key_values(model, past_key_values, max_length): """Crops the past key values up to a certain maximum length.""" new_past = [] - if model.config.is_encoder_decoder: + if isinstance(past_key_values, Cache): + past_key_values.crop(max_length) + elif model.config.is_encoder_decoder: for idx in range(len(past_key_values)): new_past.append( ( @@ -1204,8 +1206,6 @@ def _crop_past_key_values(model, past_key_values, max_length): else: for idx in range(len(past_key_values)): past_key_values[idx] = past_key_values[idx][:, :, :max_length, :] - elif isinstance(past_key_values, DynamicCache): - past_key_values.crop(max_length) elif past_key_values is not None: for idx in range(len(past_key_values)): if past_key_values[idx] != ([], []): diff --git a/src/transformers/models/autoformer/modeling_autoformer.py b/src/transformers/models/autoformer/modeling_autoformer.py index 8cb92caa1a4b..9f7f1515a278 100644 --- a/src/transformers/models/autoformer/modeling_autoformer.py +++ b/src/transformers/models/autoformer/modeling_autoformer.py @@ -370,13 +370,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Autoformer diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py index d81978f89c7e..3fb6c8b1c4b4 100755 --- a/src/transformers/models/bart/modeling_bart.py +++ b/src/transformers/models/bart/modeling_bart.py @@ -25,12 +25,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from 
...modeling_outputs import ( @@ -43,10 +43,21 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_bart import BartConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -81,15 +92,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) class BartScaledWordEmbedding(nn.Embedding): @@ -117,6 +131,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BartConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -133,73 +148,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -213,10 +229,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -278,73 +291,73 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # BartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError( + "BartSdpaAttention2 attention does not support `output_attentions`. " + "Use the argument `attn_implementation='eager'` when loading the model." 
+ ) + # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): @@ -379,10 +392,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value class BartSdpaAttention(BartAttention): @@ -390,10 +400,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions: @@ -408,6 +419,7 @@ def forward( past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer @@ -417,50 +429,55 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + causal_mask = None + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 @@ -468,23 +485,16 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -498,7 +508,7 @@ def forward( class BartEncoderLayer(nn.Module): - def __init__(self, config: BartConfig): + def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -507,6 +517,7 @@ def __init__(self, config: BartConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -568,7 +579,7 @@ def forward( class BartDecoderLayer(nn.Module): - def __init__(self, config: BartConfig): + def __init__(self, config: BartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -579,6 +590,7 @@ def __init__(self, config: BartConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -591,6 +603,7 @@ def __init__(self, config: BartConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -605,9 +618,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -626,47 +640,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -682,7 +691,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -721,6 +730,8 @@ class BartPreTrainedModel(PreTrainedModel): _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -732,6 +743,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -743,6 +757,131 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return 
attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class PretrainedBartModel(BartPreTrainedModel): def __init_subclass__(self): @@ -792,7 +931,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList([BartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([BartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -976,7 +1115,7 @@ def __init__(self, config: BartConfig, embed_tokens: Optional[nn.Embedding] = No config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([BartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" @@ -1006,6 +1145,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1071,6 +1211,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1079,43 +1222,59 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input = input_ids - input_shape = input.shape - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, + inputs_embeds = self.embed_tokens(input_ids) + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, + ) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: if self._use_flash_attention_2: @@ -1127,35 +1286,28 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( encoder_attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1], + tgt_len=seq_length, ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input, past_key_values_length) - positions = positions.to(inputs_embeds.device) + position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + position_ids = position_ids.to(inputs_embeds.device) - hidden_states = inputs_embeds + positions + hidden_states = inputs_embeds + position_ids hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1175,13 +1327,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1189,25 +1339,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1220,6 +1372,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1296,6 +1451,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1379,6 +1535,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1467,6 +1624,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1569,6 +1727,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -1651,6 +1810,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 
target_sequence_length)`, *optional*): @@ -1708,6 +1868,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] # last hidden state @@ -1795,6 +1956,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1844,6 +2006,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -1959,6 +2122,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -2008,6 +2172,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/big_bird/modeling_big_bird.py b/src/transformers/models/big_bird/modeling_big_bird.py index 5bae8894716e..2106c07e7dfd 100755 --- a/src/transformers/models/big_bird/modeling_big_bird.py +++ b/src/transformers/models/big_bird/modeling_big_bird.py @@ -1340,7 +1340,6 @@ def set_attention_type(self, value: str): attn_weights.value = self.self.value attn_weights.key = self.self.key self.self = attn_weights - self.attention_type = value if not self.training: self.self.eval() diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py index efcf2dc44360..a1d972ec1acb 100755 --- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py @@ -24,8 +24,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -36,10 +40,21 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_bigbird_pegasus import BigBirdPegasusConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) _EXPECTED_OUTPUT_SHAPE = [1, 7, 1024] @@ -69,13 +84,15 @@ class BigBirdPegasusLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__(num_embeddings, embedding_dim) - 
def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): - """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: torch.Tensor = None): + """`input_ids' shape is expected to be [bsz x seqlen].""" + + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BigBirdPegasus @@ -1114,7 +1131,6 @@ def set_attention_type(self, value: str): if value == self.attention_type: return - self.attention_type = value if value == "original_full": # copy all weights to new full attention class attn_weights = BigBirdPegasusSelfAttention(self.config) @@ -1136,7 +1152,6 @@ def forward( hidden_states, attention_mask=None, head_mask=None, - past_key_value=None, output_attentions=False, band_mask=None, from_mask=None, @@ -1147,12 +1162,11 @@ def forward( # Expand dims to enable multiplication in the self-attention module head_mask = head_mask.reshape(1, -1, 1, 1) if head_mask is not None else None - if self.config.attention_type == "original_full": + if self.attention_type == "original_full": self_outputs = self.self( hidden_states, attention_mask, head_mask, - past_key_value=past_key_value, output_attentions=output_attentions, ) else: @@ -1178,6 +1192,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BigBirdPegasusConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -1194,73 +1209,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -1274,10 +1290,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -1412,7 +1425,7 @@ def set_attention_type(self, value: str): class BigBirdPegasusDecoderLayer(nn.Module): - def __init__(self, config: BigBirdPegasusConfig): + def __init__(self, config: BigBirdPegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model self.self_attn = BigBirdPegasusDecoderAttention( @@ -1421,6 +1434,7 @@ def __init__(self, config: BigBirdPegasusConfig): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -1433,6 +1447,7 @@ def __init__(self, config: BigBirdPegasusConfig): dropout=config.attention_dropout, is_decoder=True, bias=config.use_bias, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -1448,9 +1463,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -1469,47 +1485,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -1525,7 +1537,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -1563,6 +1575,8 @@ class BigBirdPegasusPreTrainedModel(PreTrainedModel): _no_split_modules = ["BigBirdPegasusEncoderLayer", "BigBirdPegasusDecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_param_buffer_assignment = False + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -1574,6 +1588,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -1585,6 +1602,131 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return 
None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class BigBirdPegasusEncoder(BigBirdPegasusPreTrainedModel): """ @@ -1914,7 +2056,9 @@ def __init__(self, config: BigBirdPegasusConfig, embed_tokens: Optional[nn.Embed config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BigBirdPegasusDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BigBirdPegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1941,6 +2085,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): r""" Args: @@ -2006,6 +2151,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -2014,54 +2162,80 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - positions = positions.to(inputs_embeds.device) - - hidden_states = inputs_embeds + positions - + position_ids = cache_position.unsqueeze(0) + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=position_ids + ) + position_ids = position_ids.to(inputs_embeds.device) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -2080,13 +2254,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -2094,25 +2266,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -2127,6 +2301,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -2198,6 +2375,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2269,6 +2447,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -2359,6 +2538,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2432,6 +2612,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) @@ -2514,6 +2695,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, 
target_sequence_length)`, *optional*): @@ -2559,6 +2741,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] # last hidden state @@ -2646,6 +2829,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -2683,6 +2867,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -2794,6 +2979,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -2842,6 +3028,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/biogpt/modeling_biogpt.py b/src/transformers/models/biogpt/modeling_biogpt.py index 68de3b509e27..900a8fe17850 100755 --- a/src/transformers/models/biogpt/modeling_biogpt.py +++ b/src/transformers/models/biogpt/modeling_biogpt.py @@ -23,8 +23,11 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask, _prepare_4d_causal_attention_mask_for_sdpa +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, +) from ...modeling_outputs import ( BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions, @@ -34,11 +37,19 @@ from ...modeling_utils import PreTrainedModel from ...utils import ( auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, logging, ) from .configuration_biogpt import BioGptConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -55,17 +66,23 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0): + def forward( + self, + attention_mask: torch.LongTensor, + past_key_values_length: int = 0, + position_ids: Optional[torch.Tensor] = None, + ): """`input_ids_shape` is expected to be [bsz x seqlen].""" - attention_mask = attention_mask.long() + if position_ids is None: + attention_mask = attention_mask.long() - # create positions depending on attention_mask - positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 + # create positions depending on attention_mask + positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1 - # cut positions if `past_key_values_length` is > 0 
- positions = positions[:, past_key_values_length:] + # cut positions if `past_key_values_length` is > 0 + position_ids = positions[:, past_key_values_length:] - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->BioGpt @@ -95,6 +112,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BioGptConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -111,73 +129,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = 
self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -191,10 +210,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -247,10 +263,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions: @@ -265,6 +282,7 @@ def forward( past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) 
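The hand-rolled `torch.cat` bookkeeping removed above is replaced by a single call to `Cache.update`, which stores the new key/value states for `layer_idx` and returns the full history. A small sketch of that contract on a plain `DynamicCache`, with invented shapes and `cache_position` passed through the same way the rewritten attention does:

```python
# Illustrative only, not part of the patch.
import torch
from transformers.cache_utils import DynamicCache

layer_idx, bsz, heads, head_dim = 0, 1, 4, 8
cache = DynamicCache()

# Prefill: five prompt tokens are projected and stored at positions 0..4.
k = torch.randn(bsz, heads, 5, head_dim)
v = torch.randn(bsz, heads, 5, head_dim)
k_all, v_all = cache.update(k, v, layer_idx, {"cache_position": torch.arange(5)})
print(k_all.shape)  # torch.Size([1, 4, 5, 8])

# One decoding step: the single new token lands at position 5, and `update`
# again returns the complete key/value history that attention scores are taken over.
k_new = torch.randn(bsz, heads, 1, head_dim)
v_new = torch.randn(bsz, heads, 1, head_dim)
k_all, v_all = cache.update(k_new, v_new, layer_idx, {"cache_position": torch.tensor([5])})
print(cache.get_seq_length())  # 6
```

For cross-attention the new code passes `cache_position=None`, since the encoder key/values are written once and never grow.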
# if key_value_states are provided this layer is used as a cross-attention layer @@ -274,50 +292,55 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + causal_mask = None + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 @@ -325,23 +348,16 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -354,7 +370,7 @@ def forward( class BioGptDecoderLayer(nn.Module): - def __init__(self, config: BioGptConfig): + def __init__(self, config: BioGptConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.hidden_size @@ -364,6 +380,7 @@ def __init__(self, config: BioGptConfig): dropout=config.attention_probs_dropout_prob, is_decoder=True, is_causal=True, + layer_idx=layer_idx, ) self.dropout = config.hidden_dropout_prob self.activation_fn = ACT2FN[config.hidden_act] @@ -380,9 +397,10 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -398,21 +416,22 @@ def forward( use_cache (`bool`, *optional*): If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see `past_key_values`). + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states @@ -433,7 +452,7 @@ def forward( outputs += (self_attn_weights,) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -444,6 +463,8 @@ class BioGptPreTrainedModel(PreTrainedModel): base_model_prefix = "biogpt" supports_gradient_checkpointing = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): """Initialize the weights""" @@ -461,6 +482,131 @@ def _init_weights(self, module): module.bias.data.zero_() module.weight.data.fill_(1.0) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if 
self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + @auto_docstring class BioGptModel(BioGptPreTrainedModel): @@ -478,7 +624,7 @@ def __init__(self, config: BioGptConfig): ) self.embed_positions = BioGptLearnedPositionalEmbedding(config.max_position_embeddings, self.embed_dim) - self.layers = nn.ModuleList([BioGptDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([BioGptDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) self.layer_norm = nn.LayerNorm(self.embed_dim) self.gradient_checkpointing = False @@ -501,9 +647,11 @@ def forward( inputs_embeds: Optional[torch.FloatTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, # NOOP kwargs, for now ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -514,64 +662,71 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - input = input_ids - input_shape = input.size() - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) + inputs_embeds = self.embed_tokens(input_ids) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False - if attention_mask is None: - attention_mask = torch.ones( - (inputs_embeds.shape[0], inputs_embeds.shape[1] + past_key_values_length), - dtype=torch.bool, - device=inputs_embeds.device, + # initialize past_key_values + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." ) - elif attention_mask.shape[1] != past_key_values_length + input_shape[1]: - raise ValueError( - f"The provided attention mask has length {attention_mask.shape[1]}, but its length should be " - f"{past_key_values_length + input_shape[1]} (sum of the lengths of current and past inputs)" + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) - # embed positions - positions = self.embed_positions(attention_mask, past_key_values_length) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if self._use_sdpa and not output_attentions and head_mask is None: - # output_attentions=True & head_mask can not be supported when using SDPA, fall back to - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) - else: - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length - ) + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, + ) - hidden_states = inputs_embeds + positions + # embed positions + if position_ids is None: + position_ids = cache_position.unsqueeze(0) - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + position_ids = self.embed_positions(attention_mask, past_key_values_length, position_ids=position_ids) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False + hidden_states = inputs_embeds + position_ids + hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -582,32 +737,32 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, head_mask[idx] if head_mask is not None else None, None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + next_decoder_cache = layer_outputs[2 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -619,6 +774,8 @@ def forward( hidden_states = self.layer_norm(hidden_states) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() if not return_dict: return tuple( @@ -668,9 +825,11 @@ def forward( past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, + position_ids: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, **kwargs, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" @@ -688,9 +847,11 @@ def forward( inputs_embeds=inputs_embeds, past_key_values=past_key_values, use_cache=use_cache, + position_ids=position_ids, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] diff --git a/src/transformers/models/blenderbot/modeling_blenderbot.py b/src/transformers/models/blenderbot/modeling_blenderbot.py index 1eef92764281..ede764a513c2 100755 --- a/src/transformers/models/blenderbot/modeling_blenderbot.py +++ b/src/transformers/models/blenderbot/modeling_blenderbot.py @@ -26,8 +26,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -36,11 +40,22 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + 
is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from ..blenderbot_small import BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel from .configuration_blenderbot import BlenderbotConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -69,13 +84,16 @@ class BlenderbotLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__(num_embeddings, embedding_dim) - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->Blenderbot @@ -105,6 +123,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BlenderbotConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -121,73 +140,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
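As in the BART and BioGpt hunks above, the `is_updated` flag lets cross-attention project the encoder states only once and read them straight from the cache on every later step. A rough sketch of that bookkeeping, again assuming the `cache_utils` classes imported in this diff and invented shapes:

```python
# Illustrative only, not part of the patch.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

layer_idx = 0
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())
print(bool(cache.is_updated.get(layer_idx)))  # False: nothing stored for this layer yet

# First decoder step: project the encoder output once and store it.
enc_k = torch.randn(1, 4, 7, 8)
enc_v = torch.randn(1, 4, 7, 8)
cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
cache.is_updated[layer_idx] = True

# Every later step skips the k/v projections, mirroring the `is_updated` branch above.
if cache.is_updated.get(layer_idx):
    key_states = cache.cross_attention_cache.key_cache[layer_idx]
    value_states = cache.cross_attention_cache.value_cache[layer_idx]
    print(key_states.shape)  # torch.Size([1, 4, 7, 8])
```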
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -201,10 +221,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -325,7 +342,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Blenderbot, MBART->BLENDERBOT class BlenderbotDecoderLayer(nn.Module): - def __init__(self, config: BlenderbotConfig): + def __init__(self, config: BlenderbotConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -336,6 +353,7 @@ def __init__(self, config: BlenderbotConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -348,6 +366,7 @@ def __init__(self, config: BlenderbotConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -362,9 +381,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -383,47 +403,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -439,7 +455,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -449,6 +465,8 @@ class BlenderbotPreTrainedModel(PreTrainedModel): config_class = BlenderbotConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -460,6 +478,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -472,6 +493,131 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if 
isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class BlenderbotEncoder(BlenderbotPreTrainedModel): """ @@ -674,7 +820,9 @@ def __init__(self, config: BlenderbotConfig, embed_tokens: Optional[nn.Embedding config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BlenderbotDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BlenderbotDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -701,6 +849,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position: Optional[torch.Tensor] = None, ): r""" Args: @@ -767,6 +916,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -775,52 +927,79 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + ## retrieve input_ids and inputs_embeds + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - - hidden_states = inputs_embeds + positions + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
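For readers tracking the new `cache_position` plumbing introduced in this decoder: when the caller does not provide it, it is derived from the current cache length, and it then doubles as `position_ids` for the learned positional embedding. Below is a minimal sketch, under the assumption that `DynamicCache` and `EncoderDecoderCache` are importable from the top-level `transformers` namespace (true for recent releases); the sizes are toy values, not taken from any model.

import torch
from transformers import DynamicCache, EncoderDecoderCache

# Prefill: nothing cached yet, a 6-token prompt -> absolute positions 0..5
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())
seq_length = 6
past_len = past_key_values.get_seq_length()            # 0 for an empty cache
cache_position = torch.arange(past_len, past_len + seq_length)
print(cache_position)                                   # tensor([0, 1, 2, 3, 4, 5])

# A later decoding step: 6 tokens already cached, 1 new token -> position 6
past_len, seq_length = 6, 1
cache_position = torch.arange(past_len, past_len + seq_length)
print(cache_position)                                   # tensor([6])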
- ) - use_cache = False # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -839,13 +1018,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -853,25 +1030,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -887,6 +1066,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -963,6 +1145,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1041,6 +1224,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1139,6 +1323,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1225,6 +1410,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1326,6 +1512,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask 
(`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1375,6 +1562,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py index eefb467e6059..a6fa6cda00e5 100755 --- a/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py +++ b/src/transformers/models/blenderbot_small/modeling_blenderbot_small.py @@ -24,8 +24,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -34,10 +38,21 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_blenderbot_small import BlenderbotSmallConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -67,13 +82,16 @@ class BlenderbotSmallLearnedPositionalEmbedding(nn.Embedding): def __init__(self, num_embeddings: int, embedding_dim: int): super().__init__(num_embeddings, embedding_dim) - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0): + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ): """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->BlenderbotSmall @@ -89,6 +107,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[BlenderbotSmallConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -105,73 +124,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
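The positional-embedding change repeated in this file is mechanical, but seeing it in isolation clarifies why `cache_position` can be fed straight in as `position_ids`. The toy class below is an illustrative stand-in built on a plain `nn.Embedding`, not the model's own class:

import torch
from torch import nn

class ToyLearnedPositionalEmbedding(nn.Embedding):
    def forward(self, input_ids_shape, past_key_values_length=0, position_ids=None):
        if position_ids is None:
            # legacy path: rebuild positions from the past length and the new sequence length
            bsz, seq_len = input_ids_shape[:2]
            position_ids = torch.arange(past_key_values_length, past_key_values_length + seq_len)
        return super().forward(position_ids)

emb = ToyLearnedPositionalEmbedding(16, 8)               # 16 positions, hidden size 8
prefill = emb((1, 4))                                    # arange fallback -> positions 0..3
step = emb((1, 1), position_ids=torch.tensor([4]))       # one decoding step at absolute position 4
assert torch.allclose(step, emb((1, 1), past_key_values_length=4))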
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -185,10 +205,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -237,7 +254,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallEncoderLayer(nn.Module): - def __init__(self, config: BlenderbotSmallConfig): + def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -246,6 +263,7 @@ def __init__(self, config: BlenderbotSmallConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -314,7 +332,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->BlenderbotSmall, BART->BLENDERBOT_SMALL class BlenderbotSmallDecoderLayer(nn.Module): - def __init__(self, config: BlenderbotSmallConfig): + def __init__(self, config: BlenderbotSmallConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -325,6 +343,7 @@ def __init__(self, config: BlenderbotSmallConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -337,6 +356,7 @@ def __init__(self, config: BlenderbotSmallConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -351,9 +371,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: 
Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -372,47 +393,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -428,7 +444,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -438,6 +454,8 @@ class BlenderbotSmallPreTrainedModel(PreTrainedModel): config_class = BlenderbotSmallConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -449,6 +467,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif 
isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -461,6 +482,131 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class BlenderbotSmallEncoder(BlenderbotSmallPreTrainedModel): """ @@ -657,7 +803,9 @@ def __init__(self, config: BlenderbotSmallConfig, embed_tokens: Optional[nn.Embe config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([BlenderbotSmallDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [BlenderbotSmallDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -684,6 +832,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -749,6 +898,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -757,55 +909,83 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) + + inputs_embeds = inputs_embeds * self.embed_scale + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
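The deprecation message introduced here spells out the migration path; the hedged sketch below shows what it looks like from the caller's side, using arbitrary toy tensor shapes and assuming `EncoderDecoderCache` is exported from the top-level `transformers` namespace:

import torch
from transformers import EncoderDecoderCache

# Legacy format: one tuple per decoder layer of (self_k, self_v, cross_k, cross_v),
# each shaped (batch, num_heads, seq_len, head_dim).
legacy = tuple(tuple(torch.zeros(1, 2, 3, 4) for _ in range(4)) for _ in range(2))

cache = EncoderDecoderCache.from_legacy_cache(legacy)
print(cache.get_seq_length())        # 3, read from the self-attention sub-cache

# And back again, which is what `return_legacy_cache` does on the way out of the decoder:
roundtrip = cache.to_legacy_cache()
assert torch.equal(roundtrip[0][0], legacy[0][0])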
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) # BlenderbotSmall applies layer norm on hidden_states inputs_embeds = self.layernorm_embedding(inputs_embeds) - hidden_states = inputs_embeds + positions - + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -824,13 +1004,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -838,25 +1016,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -869,6 +1049,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -932,6 +1115,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1010,6 +1194,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1093,6 +1278,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1179,6 +1365,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1280,6 +1467,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask 
(`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1329,6 +1517,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/informer/modeling_informer.py b/src/transformers/models/informer/modeling_informer.py index 98acd1bd13be..8b728a19dff7 100644 --- a/src/transformers/models/informer/modeling_informer.py +++ b/src/transformers/models/informer/modeling_informer.py @@ -21,7 +21,10 @@ from torch import nn from ...activations import ACT2FN -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -251,13 +254,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesValueEmbedding with TimeSeries->Info @@ -270,7 +276,7 @@ def forward(self, x): return self.value_projection(x) -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Informer +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Informer class InformerAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -1045,7 +1051,6 @@ def forward( ) -# Copied from transformers.models.time_series_transformer.modeling_time_series_transformer.TimeSeriesTransformerDecoder with TimeSeriesTransformer->Informer,TimeSeriesTransformerConfig->InformerConfig,time-series-transformer->informer,Transformer->Informer,TimeSeries->Informer class InformerDecoder(InformerPreTrainedModel): """ Informer decoder consisting of *config.decoder_layers* layers. 
Each layer is a @@ -1403,6 +1408,7 @@ def get_encoder(self): def get_decoder(self): return self.decoder + # Ignore copy @auto_docstring def forward( self, @@ -1654,6 +1660,7 @@ def output_distribution(self, params, loc=None, scale=None, trailing_n=None) -> sliced_params = [p[:, -trailing_n:] for p in params] return self.distribution_output.distribution(sliced_params, loc=loc, scale=scale) + # Ignore copy @auto_docstring def forward( self, diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index 39503193be71..1da696ae0386 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -22,14 +22,14 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( @@ -39,10 +39,21 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_m2m_100 import M2M100Config +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -186,6 +197,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[M2M100Config] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -202,73 +214,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -282,10 +295,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -348,73 +358,73 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # M2M100FlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError( + "M2M100SdpaAttention2 attention does not support `output_attentions`. " + "Use the argument `attn_implementation='eager'` when loading the model." 
+ ) + # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): @@ -449,10 +459,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->M2M100 @@ -461,10 +468,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions: @@ -479,6 +487,7 @@ def forward( past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer @@ -488,50 +497,55 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + causal_mask = None + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 @@ -539,23 +553,16 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
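The pattern above replaces the old key/value tuples with per-layer reads and writes on the cache classes imported from `cache_utils`. A minimal, illustrative sketch of that flow with toy shapes (the class names and methods are the ones used in this diff; the sketch itself is not part of the patch):

    import torch
    from transformers.cache_utils import DynamicCache, EncoderDecoderCache

    layer_idx = 0
    bsz, num_heads, head_dim = 2, 4, 8
    cache = EncoderDecoderCache(DynamicCache(), DynamicCache())  # self-attn cache, cross-attn cache

    # decoder self-attention: append this step's key/value states at the given cache_position
    new_k = torch.randn(bsz, num_heads, 1, head_dim)
    new_v = torch.randn(bsz, num_heads, 1, head_dim)
    k, v = cache.self_attention_cache.update(new_k, new_v, layer_idx, {"cache_position": torch.tensor([0])})

    # cross-attention: computed once from the encoder output, flagged as updated, then re-used
    enc_k = torch.randn(bsz, num_heads, 7, head_dim)
    enc_v = torch.randn(bsz, num_heads, 7, head_dim)
    if not cache.is_updated.get(layer_idx, False):
        enc_k, enc_v = cache.cross_attention_cache.update(enc_k, enc_v, layer_idx)
        cache.is_updated[layer_idx] = True
    else:
        enc_k = cache.cross_attention_cache.key_cache[layer_idx]
        enc_v = cache.cross_attention_cache.value_cache[layer_idx]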
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -639,7 +646,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100, MBART->M2M100 class M2M100DecoderLayer(nn.Module): - def __init__(self, config: M2M100Config): + def __init__(self, config: M2M100Config, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -650,6 +657,7 @@ def __init__(self, config: M2M100Config): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -662,6 +670,7 @@ def __init__(self, config: M2M100Config): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -676,9 +685,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -697,47 +707,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -753,7 +759,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -766,6 +772,9 @@ class M2M100PreTrainedModel(PreTrainedModel): _no_split_modules = ["M2M100EncoderLayer", "M2M100DecoderLayer"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + # Doesn't support `compile` (dynamic control flow). 
Can be fixed but low usage model + _supports_static_cache = False def _init_weights(self, module): std = self.config.init_std @@ -777,6 +786,134 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask class M2M100Encoder(M2M100PreTrainedModel): @@ -995,7 +1132,7 @@ def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = config.d_model, self.padding_idx, ) - self.layers = nn.ModuleList([M2M100DecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([M2M100DecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" self.layer_norm = nn.LayerNorm(config.d_model) @@ -1018,6 +1155,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ): r""" Args: @@ -1083,6 +1221,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1091,41 +1232,59 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self._use_flash_attention_2: - # 2d mask is passed through the layers - combined_attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - combined_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
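The deprecation branch above converts the old tuple format on the fly. A small illustrative sketch of that round trip (toy tensors; only `EncoderDecoderCache` is the real class from `cache_utils`):

    import torch
    from transformers.cache_utils import EncoderDecoderCache

    # legacy layout: one tuple per decoder layer of (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value)
    legacy = tuple(
        tuple(torch.randn(1, 4, 3, 8) for _ in range(4))  # (bsz, num_heads, seq_len, head_dim)
        for _ in range(2)                                  # two decoder layers
    )
    cache = EncoderDecoderCache.from_legacy_cache(legacy)
    print(cache.get_seq_length())           # 3 cached self-attention positions
    legacy_again = cache.to_legacy_cache()  # what gets returned when return_legacy_cache is True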
) - else: - # 4d mask is passed through the layers - combined_attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, + ) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: if self._use_flash_attention_2: @@ -1137,12 +1296,14 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( encoder_attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1], + tgt_len=seq_length, ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=seq_length, ) # embed positions @@ -1150,7 +1311,6 @@ def forward( positions = positions.to(inputs_embeds.device) hidden_states = inputs_embeds + positions - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) if self.gradient_checkpointing and self.training: @@ -1164,7 +1324,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if output_attentions else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1187,13 +1347,11 @@ def forward( if not skip_the_layer or synced_gpus: # under fsdp or deepspeed zero3 all gpus must run in sync - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - combined_attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1201,20 +1359,22 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=combined_attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, 
output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1223,7 +1383,7 @@ def forward( continue if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1236,6 +1396,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1305,6 +1468,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1367,6 +1531,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1432,6 +1597,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1499,6 +1665,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py index f544641e0d65..ee52dd1be069 100755 --- a/src/transformers/models/marian/modeling_marian.py +++ b/src/transformers/models/marian/modeling_marian.py @@ -25,8 +25,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,10 +39,21 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_marian import MarianConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -81,13 +96,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, 
past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Marian @@ -103,6 +121,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[MarianConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -119,73 +138,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = 
self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -199,10 +219,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -251,7 +268,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->Marian, BART->MARIAN class MarianEncoderLayer(nn.Module): - def __init__(self, config: MarianConfig): + def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -260,6 +277,7 @@ def __init__(self, config: MarianConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -325,7 +343,7 @@ def 
forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->Marian, BART->MARIAN class MarianDecoderLayer(nn.Module): - def __init__(self, config: MarianConfig): + def __init__(self, config: MarianConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -336,6 +354,7 @@ def __init__(self, config: MarianConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -348,6 +367,7 @@ def __init__(self, config: MarianConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -362,9 +382,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -383,47 +404,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -439,7 +455,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -449,6 +465,8 @@ class MarianPreTrainedModel(PreTrainedModel): config_class = MarianConfig base_model_prefix = "model" supports_gradient_checkpointing = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalPositionalEmbedding]): std = self.config.init_std @@ -462,6 +480,134 @@ def _init_weights(self, module: Union[nn.Linear, nn.Embedding, MarianSinusoidalP module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = 
make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask @property def dummy_inputs(self): @@ -670,7 +816,7 @@ def __init__(self, config: MarianConfig, embed_tokens: Optional[nn.Embedding] = self.embed_positions = MarianSinusoidalPositionalEmbedding( config.max_position_embeddings, config.d_model, self.padding_idx ) - self.layers = nn.ModuleList([MarianDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([MarianDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.gradient_checkpointing = False # Initialize weights and apply final processing @@ -696,6 +842,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -761,6 +908,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -769,53 +919,83 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) + + # Important to apply outside of the above `if`, in case user passes `embeds` + inputs_embeds = inputs_embeds * self.embed_scale + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, + inputs_embeds.dtype, + tgt_len=seq_length, ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - - hidden_states = inputs_embeds + positions - + position_ids = self.embed_positions( + (batch_size, seq_length), past_key_values_length, position_ids=cache_position + ) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
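As used just above, `cache_position` is derived from the cached length when the caller does not supply it, and it is also what `embed_positions` receives as `position_ids`. A minimal illustrative sketch with arbitrary numbers:

    import torch

    past_key_values_length = 6   # tokens already stored in the self-attention cache
    seq_length = 1               # one new token per decoding step
    cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length)
    print(cache_position)        # tensor([6]): the new token is written to cache slot 6 and
                                 # looked up at position 6 in the positional embedding table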
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -833,13 +1013,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -847,25 +1025,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -878,6 +1058,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -996,6 +1179,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Seq2SeqModelOutput: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1078,6 +1262,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1252,6 +1437,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Seq2SeqLMOutput: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1326,6 +1512,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1430,6 +1617,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, 
*optional*): @@ -1479,6 +1667,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py index 88c0d79af34c..1cc334db7e0f 100755 --- a/src/transformers/models/mbart/modeling_mbart.py +++ b/src/transformers/models/mbart/modeling_mbart.py @@ -24,12 +24,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) from ...modeling_flash_attention_utils import flash_attn_supports_top_left_mask, is_flash_attn_available from ...modeling_outputs import ( @@ -42,10 +42,21 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_mbart import MBartConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + if is_flash_attn_available(): from ...modeling_flash_attention_utils import _flash_attention_forward @@ -85,15 +96,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->MBart @@ -123,6 +137,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[MBartConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -139,73 +154,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
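The warning above exists because each decoder attention module now indexes the shared cache with its own `layer_idx`. A toy illustration of the wiring (the class below is a stand-in, not one of the model classes):

    import torch.nn as nn

    class ToyDecoderLayer(nn.Module):  # stand-in for the MBart/Marian/M2M100 decoder layers
        def __init__(self, layer_idx: int):
            super().__init__()
            self.layer_idx = layer_idx  # forwarded to the self-attention and cross-attention modules

    layers = nn.ModuleList(ToyDecoderLayer(layer_idx=i) for i in range(4))
    print([layer.layer_idx for layer in layers])  # [0, 1, 2, 3] -> one slot per layer in the cache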
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -219,10 +235,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -285,73 +298,73 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # MBartFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError( + "MBartSdpaAttention2 attention does not support `output_attentions`. " + "Use the argument `attn_implementation='eager'` when loading the model." 
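# Usage sketch (not part of the patch): the guard above means attention weights are only
# available with the eager implementation; `attn_implementation` is the standard
# `from_pretrained` switch, and `facebook/mbart-large-50` is just an example checkpoint.
from transformers import MBartForConditionalGeneration

model = MBartForConditionalGeneration.from_pretrained(
    "facebook/mbart-large-50",       # any MBart checkpoint
    attn_implementation="eager",     # required if `output_attentions=True` is needed
)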
+ ) + # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): @@ -386,10 +399,7 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, None, past_key_value # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MBart @@ -398,10 +408,11 @@ def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions: @@ -416,6 +427,7 @@ def forward( past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) # if key_value_states are provided this layer is used as a cross-attention layer @@ -425,50 +437,55 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True + + causal_mask = None + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 @@ -476,23 +493,16 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
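# Illustrative sketch (not taken from the patch; toy shapes and projections): how the rewritten
# attention layers are expected to use `EncoderDecoderCache`. Only `DynamicCache`,
# `EncoderDecoderCache`, `update()`, `key_cache`/`value_cache` and `is_updated` are the real
# `transformers.cache_utils` API, assuming the transformers version this patch targets.
import torch
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

layer_idx, bsz, num_heads, head_dim = 0, 1, 2, 4
cache = EncoderDecoderCache(DynamicCache(), DynamicCache())

def attend(hidden_len, encoder_len, cache_position):
    # self-attention: new key/value states are appended to the self-attention cache every step
    k_self = torch.randn(bsz, num_heads, hidden_len, head_dim)
    v_self = torch.randn(bsz, num_heads, hidden_len, head_dim)
    k_self, v_self = cache.self_attention_cache.update(
        k_self, v_self, layer_idx, {"cache_position": cache_position}
    )
    # cross-attention: encoder key/value states are projected once, then re-used via `is_updated`
    if not cache.is_updated.get(layer_idx):
        k_x = torch.randn(bsz, num_heads, encoder_len, head_dim)
        v_x = torch.randn(bsz, num_heads, encoder_len, head_dim)
        k_x, v_x = cache.cross_attention_cache.update(k_x, v_x, layer_idx)
        cache.is_updated[layer_idx] = True
    else:
        k_x = cache.cross_attention_cache.key_cache[layer_idx]
        v_x = cache.cross_attention_cache.value_cache[layer_idx]
    return k_self.shape[-2], k_x.shape[-2]

print(attend(5, 7, torch.arange(0, 5)))  # prefill: (5, 7)
print(attend(1, 7, torch.tensor([5])))   # decode step: (6, 7) -- cross K/V are not recomputed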
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) return attn_output, None, past_key_value @@ -574,7 +584,7 @@ def forward( class MBartDecoderLayer(nn.Module): - def __init__(self, config: MBartConfig): + def __init__(self, config: MBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -585,6 +595,7 @@ def __init__(self, config: MBartConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -597,6 +608,7 @@ def __init__(self, config: MBartConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -611,9 +623,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -632,47 +645,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -688,7 +697,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -726,6 +735,8 @@ class MBartPreTrainedModel(PreTrainedModel): _no_split_modules = ["MBartDecoderLayer", "MBartEncoderLayer", "MBartAttention"] _supports_flash_attn_2 = True _supports_sdpa = True + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -737,6 +748,9 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() @property def dummy_inputs(self): @@ -748,6 +762,131 @@ def dummy_inputs(self): } return dummy_inputs + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = 
make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. 
+ """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + class MBartEncoder(MBartPreTrainedModel): """ @@ -965,7 +1104,7 @@ def __init__(self, config: MBartConfig, embed_tokens: Optional[nn.Embedding] = N config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([MBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([MBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.config = config self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -995,6 +1134,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -1060,6 +1200,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1069,42 +1212,58 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input = input_ids - input_shape = input.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if self.config._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self.config._attn_implementation == "sdpa" and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." + ) + use_cache = False + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, + ) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: if self.config._attn_implementation == "flash_attention_2": @@ -1116,34 +1275,27 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( encoder_attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1], + tgt_len=seq_length, ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input, past_key_values_length) + position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) - hidden_states = inputs_embeds + positions.to(inputs_embeds.device) + hidden_states = inputs_embeds + position_ids.to(inputs_embeds.device) hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing`. Setting `use_cache=False`..." 
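# Illustrative sketch (toy shapes; the `EncoderDecoderCache` calls mirror the decoder-forward
# code above): a legacy tuple-of-tuples cache is wrapped once, `cache_position` is derived
# from the number of already-cached tokens, and the cache can be converted back for callers
# that still expect tuples.
import torch
from transformers.cache_utils import EncoderDecoderCache

bsz, heads, head_dim, past_len, new_len = 1, 2, 4, 3, 1
legacy = (  # one layer: (self_k, self_v, cross_k, cross_v)
    tuple(torch.randn(bsz, heads, past_len, head_dim) for _ in range(2))
    + tuple(torch.randn(bsz, heads, 7, head_dim) for _ in range(2)),
)
past_key_values = EncoderDecoderCache.from_legacy_cache(legacy)

past_seen = past_key_values.get_seq_length()                    # 3 cached self-attention tokens
cache_position = torch.arange(past_seen, past_seen + new_len)   # tensor([3])
# the same indices feed the learned positional embeddings (position_ids=cache_position above)

assert past_key_values.to_legacy_cache()[0][0].shape[-2] == past_len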
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -1162,13 +1314,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -1176,25 +1326,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1209,6 +1361,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1278,6 +1433,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqModelOutput, Tuple[torch.FloatTensor]]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1350,6 +1506,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1433,6 +1590,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Seq2SeqLMOutput, Tuple[torch.FloatTensor]]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1529,6 +1687,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1609,6 +1768,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" decoder_input_ids 
(`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1666,6 +1826,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] # last hidden state @@ -1754,6 +1915,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqQuestionAnsweringModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1803,6 +1965,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) sequence_output = outputs[0] @@ -1915,6 +2078,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1964,6 +2128,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/musicgen/modeling_musicgen.py b/src/transformers/models/musicgen/modeling_musicgen.py index 835756cb7703..00a0afd62ad5 100644 --- a/src/transformers/models/musicgen/modeling_musicgen.py +++ b/src/transformers/models/musicgen/modeling_musicgen.py @@ -146,7 +146,7 @@ def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Musicgen +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Musicgen class MusicgenAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -305,7 +305,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Musicgen +# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->Musicgen class MusicgenFlashAttention2(MusicgenAttention): """ Musicgen flash attention module. 
This module inherits from `MusicgenAttention` as the weights of the module stays @@ -589,7 +589,7 @@ def __init__(self, config: MusicgenDecoderConfig): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim, bias=False) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py index 84efc8928872..c5328d064769 100644 --- a/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py +++ b/src/transformers/models/musicgen_melody/modeling_musicgen_melody.py @@ -159,7 +159,7 @@ def forward(self, inputs_embeds: torch.Tensor, past_key_values_length: int = 0): return self.weights.index_select(0, position_ids.view(-1)).detach() -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->MusicgenMelody +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->MusicgenMelody class MusicgenMelodyAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -318,7 +318,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->MusicgenMelody +# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->MusicgenMelody class MusicgenMelodyFlashAttention2(MusicgenMelodyAttention): """ MusicgenMelody flash attention module. This module inherits from `MusicgenMelodyAttention` as the weights of the module stays @@ -441,7 +441,7 @@ def forward( return attn_output, attn_weights, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartSdpaAttention with Bart->MusicgenMelody +# Copied from transformers.models.hubert.modeling_hubert.HubertSdpaAttention with Hubert->MusicgenMelody class MusicgenMelodySdpaAttention(MusicgenMelodyAttention): def forward( self, diff --git a/src/transformers/models/mvp/modeling_mvp.py b/src/transformers/models/mvp/modeling_mvp.py index 1c3862349681..29874999e052 100644 --- a/src/transformers/models/mvp/modeling_mvp.py +++ b/src/transformers/models/mvp/modeling_mvp.py @@ -25,7 +25,10 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -60,27 +63,30 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start return shifted_input_ids -# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->MVP +# Copied from transformers.models.bart.modeling_bart.BartLearnedPositionalEmbedding with Bart->Mvp class MvpLearnedPositionalEmbedding(nn.Embedding): """ This module learns positional embeddings up to a fixed maximum size. """ def __init__(self, num_embeddings: int, embedding_dim: int): - # MVP is set up so that if padding_idx is specified then offset the embedding ids by 2 + # Mvp is set up so that if padding_idx is specified then offset the embedding ids by 2 # and adjust num_embeddings appropriately. 
Other models don't have this hack self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) class MvpAttention(nn.Module): diff --git a/src/transformers/models/nllb_moe/modeling_nllb_moe.py b/src/transformers/models/nllb_moe/modeling_nllb_moe.py index 79002aee3178..3729fd6e1815 100644 --- a/src/transformers/models/nllb_moe/modeling_nllb_moe.py +++ b/src/transformers/models/nllb_moe/modeling_nllb_moe.py @@ -460,7 +460,7 @@ def forward(self, hidden_states: torch.Tensor, padding_mask: Optional[torch.Tens return hidden_states, (router_probs, top_1_expert_index) -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->NllbMoe,key_value_states->encoder_hidden_states +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->NllbMoe,key_value_states->encoder_hidden_states class NllbMoeAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py index 983961f3b4e2..c10a55b1b024 100644 --- a/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py +++ b/src/transformers/models/patchtsmixer/modeling_patchtsmixer.py @@ -26,6 +26,7 @@ from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_patchtsmixer import PatchTSMixerConfig @@ -247,6 +248,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PatchTSMixerConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -263,15 +265,23 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
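# Illustrative sketch (toy tensors; real `DynamicCache` API): why the attention constructors
# above now thread `layer_idx` through. Each layer reads and writes its own slot in the shared
# cache, so an attention module without an index cannot address its cached keys/values.
import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
for layer_idx in range(2):  # mirrors `MBartDecoderLayer(config, layer_idx=i)` above
    k = torch.full((1, 2, 3, 4), float(layer_idx))
    v = torch.zeros(1, 2, 3, 4)
    cache.update(k, v, layer_idx)

assert cache.key_cache[0][0, 0, 0, 0] == 0.0  # layer 0's cached keys
assert cache.key_cache[1][0, 0, 0, 0] == 1.0  # layer 1's cached keys stay separate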
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - + # Ignore copy + @deprecate_kwarg("key_value_states", version="4.55") + @deprecate_kwarg("past_key_value", version="4.55") + @deprecate_kwarg("cache_position", version="4.55") def forward( self, hidden_states: torch.Tensor, @@ -280,56 +290,21 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -343,10 +318,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -390,7 +362,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped, None class PatchMixerBlock(nn.Module): diff --git a/src/transformers/models/patchtst/modeling_patchtst.py b/src/transformers/models/patchtst/modeling_patchtst.py index e5785974be5b..57b69b6b5ebc 100755 --- a/src/transformers/models/patchtst/modeling_patchtst.py +++ b/src/transformers/models/patchtst/modeling_patchtst.py @@ -32,7 +32,7 @@ logger = logging.get_logger(__name__) -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->PatchTST +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->PatchTST class PatchTSTAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py index 9d0eafa9c94d..9f57f79df4ed 100755 --- a/src/transformers/models/pegasus/modeling_pegasus.py +++ b/src/transformers/models/pegasus/modeling_pegasus.py @@ -25,8 +25,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -35,10 +39,21 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_pegasus import PegasusConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -82,13 +97,16 @@ def _init_weight(self): 
self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Pegasus @@ -104,6 +122,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PegasusConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -120,73 +139,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = 
past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -200,10 +220,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -324,7 +341,7 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Pegasus, MBART->PEGASUS class 
PegasusDecoderLayer(nn.Module): - def __init__(self, config: PegasusConfig): + def __init__(self, config: PegasusConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -335,6 +352,7 @@ def __init__(self, config: PegasusConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -347,6 +365,7 @@ def __init__(self, config: PegasusConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -361,9 +380,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -382,47 +402,43 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + 
cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -438,7 +454,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -461,6 +477,134 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask class PegasusEncoder(PegasusPreTrainedModel): @@ -692,7 +836,7 @@ def __init__(self, config: PegasusConfig, embed_tokens: Optional[nn.Embedding] = config.d_model, self.padding_idx, ) - self.layers = nn.ModuleList([PegasusDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([PegasusDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -749,6 +893,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -814,6 +959,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. 
It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -822,53 +970,79 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale + inputs_embeds = self.embed_tokens(input_ids) + + # important to apply scale outside of `if` in case users pass `embeds` + inputs_embeds = inputs_embeds * self.embed_scale + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input_shape, past_key_values_length) - + positions = self.embed_positions((batch_size, seq_length), past_key_values_length, position_ids=cache_position) hidden_states = inputs_embeds + positions - hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -887,13 +1061,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -901,25 +1073,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -934,6 +1108,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1020,6 +1197,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1099,6 +1277,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1205,6 +1384,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1280,6 +1460,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) + self.final_logits_bias @@ -1406,6 +1587,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, 
decoder_attention_heads)`, *optional*): @@ -1455,6 +1637,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/pegasus_x/modeling_pegasus_x.py b/src/transformers/models/pegasus_x/modeling_pegasus_x.py index 9656b3afe77f..c9554255e470 100755 --- a/src/transformers/models/pegasus_x/modeling_pegasus_x.py +++ b/src/transformers/models/pegasus_x/modeling_pegasus_x.py @@ -25,8 +25,12 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, + _prepare_4d_attention_mask, +) from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -34,10 +38,21 @@ Seq2SeqModelOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_pegasus_x import PegasusXConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -99,20 +114,24 @@ def __init__(self, embed_dim, max_scale: int = 10000.0): self.max_scale = max_scale @torch.no_grad() - def forward(self, input_embeds: torch.Tensor, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_embeds: torch.Tensor, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" batch_size, seq_len = input_embeds.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device - )[:, None] + if position_ids is None: + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=input_embeds.device + )[:, None] + pe = torch.zeros((seq_len, self.embed_dim), device=input_embeds.device, dtype=input_embeds.dtype) half_d_feature = self.embed_dim // 2 div_term = torch.exp( torch.arange(half_d_feature, device=input_embeds.device, dtype=torch.int64).type_as(input_embeds) * -(np.log(float(self.max_scale)) / (half_d_feature - 1)) ) - pe[:, :half_d_feature] = torch.sin(positions * div_term) - pe[:, half_d_feature:] = torch.cos(positions * div_term) + pe[:, :half_d_feature] = torch.sin(position_ids * div_term) + pe[:, half_d_feature:] = torch.cos(position_ids * div_term) return pe[None].expand(batch_size, -1, -1) @@ -129,6 +148,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PegasusXConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -145,73 +165,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. 
Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. 
- # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -225,10 +246,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -625,7 +643,7 @@ def unpad_local_tokens(cls, padded_hidden_states, block_size): class PegasusXDecoderLayer(nn.Module): - def __init__(self, config: PegasusXConfig): + def __init__(self, config: PegasusXConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -635,6 +653,7 @@ def __init__(self, config: PegasusXConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -647,6 +666,7 @@ def __init__(self, config: PegasusXConfig): dropout=config.attention_dropout, is_decoder=True, bias=False, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -659,9 +679,10 @@ def forward( attention_mask: Optional[torch.Tensor] = None, encoder_hidden_states: Optional[torch.Tensor] = None, encoder_attention_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Args: @@ -677,45 +698,40 @@ def forward( Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. 
use_cache: Whether to us KV cache for decoding + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -731,7 +747,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -742,6 +758,8 @@ class PegasusXPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = [r"PegasusXEncoderLayer", r"PegasusXDecoderLayer"] + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -751,6 +769,134 @@ def _init_weights(self, module): module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if 
self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask class PegasusXEncoder(PegasusXPreTrainedModel): @@ -990,7 +1136,7 @@ def __init__(self, config: PegasusXConfig, embed_tokens: Optional[nn.Embedding] ) self.embed_positions = PegasusXSinusoidalPositionalEmbedding(config.d_model) - self.layers = nn.ModuleList([PegasusXDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([PegasusXDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self.layer_norm = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -1015,6 +1161,7 @@ def forward( output_attentions=None, output_hidden_states=None, return_dict=None, + cache_position=None, ): r""" Args: @@ -1069,6 +1216,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1077,55 +1227,78 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input_shape = input_ids.size() - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) + + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, ) # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(inputs_embeds, past_key_values_length) - - positions = positions.to(inputs_embeds.device) - - hidden_states = inputs_embeds + positions - + position_ids = cache_position.unsqueeze(1) + position_ids = self.embed_positions(inputs_embeds, past_key_values_length, position_ids) + position_ids = position_ids.to(inputs_embeds.device) + hidden_states = inputs_embeds + position_ids hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None for idx, decoder_layer in enumerate(self.layers): # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description) @@ -1136,33 +1309,33 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -1177,6 +1350,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1264,6 +1440,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1334,6 +1511,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1419,6 +1597,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple, Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1464,6 +1643,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/plbart/modeling_plbart.py b/src/transformers/models/plbart/modeling_plbart.py index 22157b29c775..c82cfb906976 100644 --- a/src/transformers/models/plbart/modeling_plbart.py +++ b/src/transformers/models/plbart/modeling_plbart.py @@ -24,12 +24,12 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...generation import GenerationMixin from ...modeling_attn_mask_utils import ( + AttentionMaskConverter, _prepare_4d_attention_mask, _prepare_4d_attention_mask_for_sdpa, - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, ) 
from ...modeling_outputs import ( BaseModelOutput, @@ -40,10 +40,21 @@ Seq2SeqSequenceClassifierOutput, ) from ...modeling_utils import PreTrainedModel -from ...utils import auto_docstring, logging +from ...utils import ( + auto_docstring, + is_torch_flex_attn_available, + is_torchdynamo_compiling, + logging, +) from .configuration_plbart import PLBartConfig +if is_torch_flex_attn_available(): + from torch.nn.attention.flex_attention import BlockMask + + from ...integrations.flex_attention import make_flex_block_causal_mask + + logger = logging.get_logger(__name__) @@ -80,15 +91,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->PLBart @@ -118,6 +132,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[PLBartConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -134,73 +149,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -214,10 +230,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -266,7 +279,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->PLBart, BART->PLBART class PLBartEncoderLayer(nn.Module): - def __init__(self, config: PLBartConfig): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -275,6 +288,7 @@ def __init__(self, config: PLBartConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -341,7 +355,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->PLBart, BART->PLBART class PLBartDecoderLayer(nn.Module): - def __init__(self, config: PLBartConfig): + def __init__(self, config: PLBartConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -352,6 +366,7 @@ def __init__(self, config: PLBartConfig): is_decoder=True, is_causal=True, config=config, + layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -364,6 +379,7 @@ def __init__(self, config: PLBartConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -378,9 +394,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, 
output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -399,47 +416,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -455,7 +467,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += (present_key_value,) + outputs += (past_key_value,) return outputs @@ -491,6 +503,8 @@ class PLBartPreTrainedModel(PreTrainedModel): base_model_prefix = "model" supports_gradient_checkpointing = True _no_split_modules = ["PLBartDecoderLayer", "PLBartEncoderLayer"] + _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.init_std @@ -502,6 +516,134 @@ def _init_weights(self, module): module.weight.data.normal_(mean=0.0, std=std) if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.weight.data.fill_(1.0) + module.bias.data.zero_() + + # Copied from 
transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: Union[torch.Tensor, "BlockMask"], + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool = False, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and (attention_mask == 0.0).any(): + return attention_mask + return None + if self.config._attn_implementation == "flex_attention": + if isinstance(attention_mask, torch.Tensor): + attention_mask = make_flex_block_causal_mask(attention_mask) + return attention_mask + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype = input_tensor.dtype + sequence_length = input_tensor.shape[1] + if using_compilable_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type in ["cuda", "xpu", "npu"] + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. 
+ sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to( + causal_mask.device + ) + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask # Copied from transformers.models.bart.modeling_bart.BartEncoder with Bart->PLBart @@ -537,7 +679,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList([PLBartEncoderLayer(config) for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([PLBartEncoderLayer(config, layer_idx=i) for i in range(config.encoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -722,7 +864,7 @@ def __init__(self, config: PLBartConfig, embed_tokens: Optional[nn.Embedding] = config.max_position_embeddings, config.d_model, ) - self.layers = nn.ModuleList([PLBartDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList([PLBartDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)]) self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" self._use_sdpa = config._attn_implementation == "sdpa" @@ -752,6 +894,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -817,6 +960,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
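For reference, here is a minimal standalone sketch (not part of the patch) of the arithmetic that `_prepare_4d_causal_attention_mask_with_cache_position` performs during a decode step against a partially filled static cache; the batch size, lengths and `cache_position` values are invented for illustration, and the padding-mask merge is left out:

import torch

# illustrative sizes only: 2 new tokens written into a 5-slot static cache
batch_size, sequence_length, target_length = 1, 2, 5
cache_position = torch.tensor([3, 4])  # the new tokens occupy cache slots 3 and 4
dtype = torch.float32
min_dtype = torch.finfo(dtype).min

# same construction as in the helper above
causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
causal_mask = torch.triu(causal_mask, diagonal=1)
causal_mask *= torch.arange(target_length) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)

print((causal_mask == 0).long())
# tensor([[[[1, 1, 1, 1, 0],    <- the query in slot 3 may attend to slots 0-3
#           [1, 1, 1, 1, 1]]]]) <- the query in slot 4 may attend to slots 0-4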
""" output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -825,43 +971,59 @@ def forward( use_cache = use_cache if use_cache is not None else self.config.use_cache return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: + if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") - elif input_ids is not None: - input = input_ids - input_shape = input.shape - input_ids = input_ids.view(-1, input_shape[-1]) - elif inputs_embeds is not None: - input_shape = inputs_embeds.size()[:-1] - input = inputs_embeds[:, :, -1] - else: - raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + if input_ids is not None: + input_ids = input_ids.view(-1, input_ids.shape[-1]) if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input) - - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions and cross_attn_head_mask is None: - # output_attentions=True & cross_attn_head_mask can not be supported when using SDPA, and we fall back on - # the manual implementation that requires a 4D causal mask in all cases. - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - input_shape, - inputs_embeds, - past_key_values_length, + inputs_embeds = self.embed_tokens(input_ids) + + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, input_shape, inputs_embeds, past_key_values_length + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + + batch_size, seq_length = inputs_embeds.size()[:-1] + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device ) + if attention_mask is None and not is_torchdynamo_compiling(): + # required mask seq length can be calculated via length of past cache + mask_seq_length = past_key_values_length + seq_length + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + + self_attn_cache = ( + past_key_values.self_attention_cache + if isinstance(past_key_values, EncoderDecoderCache) + else past_key_values + ) + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + self_attn_cache, + output_attentions, + ) + # expand encoder attention mask if encoder_hidden_states is not None and encoder_attention_mask is not None: if self._use_flash_attention_2: @@ -873,35 +1035,28 @@ def forward( encoder_attention_mask = _prepare_4d_attention_mask_for_sdpa( encoder_attention_mask, inputs_embeds.dtype, - tgt_len=input_shape[-1], + tgt_len=seq_length, ) else: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] encoder_attention_mask = _prepare_4d_attention_mask( - encoder_attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1] + encoder_attention_mask, inputs_embeds.dtype, tgt_len=seq_length ) # embed positions - positions = self.embed_positions(input, past_key_values_length) - positions = positions.to(inputs_embeds.device) + position_ids = self.embed_positions(input, past_key_values_length, position_ids=cache_position) + position_ids = position_ids.to(inputs_embeds.device) - hidden_states = inputs_embeds + positions + hidden_states = inputs_embeds + position_ids hidden_states = self.layernorm_embedding(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
- ) - use_cache = False - # decoder layers all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -921,13 +1076,11 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, encoder_hidden_states, encoder_attention_mask, head_mask[idx] if head_mask is not None else None, @@ -935,25 +1088,27 @@ def forward( None, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_attention_mask, layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -966,6 +1121,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1034,6 +1192,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqModelOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1108,6 +1267,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1190,6 +1350,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.Tensor] = None, ) -> Union[Tuple[torch.Tensor], Seq2SeqLMOutput]: r""" decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1267,6 +1428,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) lm_logits = self.lm_head(outputs[0]) lm_logits = lm_logits + self.final_logits_bias.to(lm_logits.device) @@ -1348,6 +1510,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, Seq2SeqSequenceClassifierOutput]: r""" 
decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*): @@ -1405,6 +1568,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] # last hidden state @@ -1522,6 +1686,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]: r""" cross_attn_head_mask (`torch.Tensor` of shape `(decoder_layers, decoder_attention_heads)`, *optional*): @@ -1571,6 +1736,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) logits = self.lm_head(outputs[0]) diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index 5218a48e480b..e4208fc8cfd4 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -127,7 +127,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() diff --git a/src/transformers/models/roformer/modeling_roformer.py b/src/transformers/models/roformer/modeling_roformer.py index 7e4f1ff9f28b..1fdccf728463 100644 --- a/src/transformers/models/roformer/modeling_roformer.py +++ b/src/transformers/models/roformer/modeling_roformer.py @@ -67,13 +67,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) def load_tf_weights_in_roformer(model, config, tf_checkpoint_path): diff --git a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py index d2feaafa8581..42f3e4b577c3 100755 --- a/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py +++ b/src/transformers/models/seamless_m4t/modeling_seamless_m4t.py @@ -28,7 +28,10 @@ from ...generation import GenerationMixin from ...integrations.deepspeed import is_deepspeed_zero3_enabled from ...integrations.fsdp import is_fsdp_managed_module -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) from 
...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPastAndCrossAttentions, @@ -1003,6 +1006,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[SeamlessM4TConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -1019,6 +1023,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py index 69dc83296731..c62b97fb89ce 100644 --- a/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +++ b/src/transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py @@ -909,6 +909,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[SeamlessM4Tv2Config] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -925,6 +926,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 812f8fb4bc74..b4d8584fe0ce 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -31,6 +31,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging +from ...utils.deprecation import deprecate_kwarg from .configuration_sew import SEWConfig @@ -385,6 +386,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[SEWConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = embed_dim @@ -401,15 +403,23 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." 
+ ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - + # Ignore copy + @deprecate_kwarg("key_value_states", version="4.55") + @deprecate_kwarg("past_key_value", version="4.55") + @deprecate_kwarg("cache_position", version="4.55") def forward( self, hidden_states: torch.Tensor, @@ -418,56 +428,21 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + key_states = self.k_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -481,10 +456,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -528,7 +500,7 @@ def forward( attn_output = self.out_proj(attn_output) - return attn_output, attn_weights_reshaped, past_key_value + return attn_output, attn_weights_reshaped, None # Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->SEW @@ -547,9 +519,10 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). self._flash_attn_uses_top_left_mask = flash_attn_supports_top_left_mask() - def _reshape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim) - + # Ignore copy + @deprecate_kwarg("key_value_states", version="4.55") + @deprecate_kwarg("past_key_value", version="4.55") + @deprecate_kwarg("cache_position", version="4.55") def forward( self, hidden_states: torch.Tensor, @@ -558,62 +531,30 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None + # SEWFlashAttention2 attention does not support output_attentions + if output_attentions: + raise ValueError( + "SEWFlashAttention2 attention does not support `output_attentions`. " + "Use the argument `attn_implementation='eager'` when loading the model."
+ ) bsz, q_len, _ = hidden_states.size() # get query proj - query_states = self._reshape(self.q_proj(hidden_states), -1, bsz) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0].transpose(1, 2) - value_states = past_key_value[1].transpose(1, 2) - elif is_cross_attention: - # cross_attentions - key_states = self._reshape(self.k_proj(key_value_states), -1, bsz) - value_states = self._reshape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) - value_states = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) - else: - # self_attention - key_states = self._reshape(self.k_proj(hidden_states), -1, bsz) - value_states = self._reshape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states.transpose(1, 2), value_states.transpose(1, 2)) - - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value[0].shape[-2] + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim) # In PEFT, usually we cast the layer norms in float32 for training stability reasons # therefore the input hidden states gets silently casted in float32. Hence, we need # cast them back in the correct dtype just to be sure everything works as expected. # This might slowdown training & inference so it is recommended to not cast the LayerNorms # in fp32. 
(LlamaRMSNorm handles it correctly) - input_dtype = query_states.dtype if input_dtype == torch.float32: if torch.is_autocast_enabled(): @@ -648,14 +589,13 @@ def forward( attn_output = attn_output.reshape(bsz, q_len, -1) attn_output = self.out_proj(attn_output) - if not output_attentions: - attn_weights = None - - return attn_output, attn_weights, past_key_value + return attn_output, None, None class SEWSdpaAttention(SEWAttention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->SEW + @deprecate_kwarg("key_value_states", version="4.55") + @deprecate_kwarg("past_key_value", version="4.55") + @deprecate_kwarg("cache_position", version="4.55") def forward( self, hidden_states: torch.Tensor, @@ -664,6 +604,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" if output_attentions: @@ -675,62 +616,35 @@ def forward( return super().forward( hidden_states, key_value_states=key_value_states, - past_key_value=past_key_value, attention_mask=attention_mask, output_attentions=output_attentions, ) - # if key_value_states are provided this layer is used as a cross-attention layer - # for the decoder - is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): - # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) - else: - # self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. 
Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) - - query_states = self._shape(query_states, tgt_len, bsz) + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + causal_mask = None + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and causal_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case tgt_len == 1. - is_causal = True if self.is_causal and attention_mask is None and tgt_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False # NOTE: SDPA with memory-efficient backend is currently (torch==2.1.2) bugged when using non-contiguous inputs and a custom attn_mask, # but we are fine here as `_shape` do call `.contiguous()`. Reference: https://github.com/pytorch/pytorch/issues/112577 @@ -738,26 +652,19 @@ def forward( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.dropout if self.training else 0.0, is_causal=is_causal, ) - if attn_output.size() != (bsz, self.num_heads, tgt_len, self.head_dim): - raise ValueError( - f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" - f" {attn_output.size()}" - ) - - attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.transpose(1, 2).contiguous() # Use the `embed_dim` from the config (stored in the class) rather than `hidden_state` because `attn_output` can be # partitioned across GPUs when using tensor-parallelism. 
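The SDPA forward above no longer validates the mask shape; it slices the incoming 4D mask to the current key length and only falls back to `is_causal=True` when no mask is present and more than one query token is processed. The snippet below is a hedged sketch of that dispatch rule outside the model, with made-up tensor sizes:

import torch
import torch.nn.functional as F

bsz, num_heads, tgt_len, head_dim, kv_len = 2, 4, 3, 8, 6
query_states = torch.randn(bsz, num_heads, tgt_len, head_dim)
key_states = torch.randn(bsz, num_heads, kv_len, head_dim)
value_states = torch.randn(bsz, num_heads, kv_len, head_dim)

# the provided 4D mask may cover more slots than the cache currently holds,
# so it is sliced to the actual key length rather than shape-checked
attention_mask = torch.zeros(bsz, 1, tgt_len, 10)
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

# mirrors `is_causal = True if self.is_causal and causal_mask is None and tgt_len > 1 else False`
is_causal = causal_mask is None and tgt_len > 1
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=causal_mask, is_causal=is_causal
)
print(attn_output.shape)  # torch.Size([2, 4, 3, 8])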
- attn_output = attn_output.reshape(bsz, tgt_len, self.embed_dim) - + attn_output = attn_output.view(bsz, tgt_len, self.embed_dim) attn_output = self.out_proj(attn_output) - return attn_output, None, past_key_value + return attn_output, None, None SEW_ATTENTION_CLASSES = { diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index b5842217fce8..8db0674633fb 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -161,7 +161,7 @@ def create_position_ids_from_input_ids( return incremental_indices.long() + padding_idx -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Speech2Text +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Speech2Text class Speech2TextAttention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -392,7 +392,7 @@ def forward( return outputs -# Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT +# copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->Speech2Text, MBART->SPEECH_TO_TEXT class Speech2TextDecoderLayer(nn.Module): def __init__(self, config: Speech2TextConfig): super().__init__() diff --git a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py index 6ba0d88e94f1..4936ae56366d 100644 --- a/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py +++ b/src/transformers/models/time_series_transformer/modeling_time_series_transformer.py @@ -22,6 +22,7 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, EncoderDecoderCache from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask from ...modeling_outputs import ( BaseModelOutput, @@ -32,10 +33,7 @@ ) from ...modeling_utils import PreTrainedModel from ...time_series_utils import NegativeBinomialOutput, NormalOutput, StudentTOutput -from ...utils import ( - auto_docstring, - logging, -) +from ...utils import auto_docstring, logging from .configuration_time_series_transformer import TimeSeriesTransformerConfig @@ -246,13 +244,16 @@ def _init_weight(self): self.weight = nn.Parameter(out, requires_grad=False) @torch.no_grad() - def forward(self, input_ids_shape: torch.Size, past_key_values_length: int = 0) -> torch.Tensor: + def forward( + self, input_ids_shape: torch.Size, past_key_values_length: int = 0, position_ids: Optional[torch.Tensor] = None + ) -> torch.Tensor: """`input_ids_shape` is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids_shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ) - return super().forward(positions) + if position_ids is None: + bsz, seq_len = input_ids_shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ) + return super().forward(position_ids) class TimeSeriesValueEmbedding(nn.Module): @@ -277,6 +278,7 @@ def __init__( bias: bool = True, is_causal: bool = False, config: Optional[TimeSeriesTransformerConfig] = None, + layer_idx: Optional[int] = None, ): super().__init__() self.embed_dim = 
embed_dim @@ -293,73 +295,74 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder self.is_causal = is_causal + self.layer_idx = layer_idx + if layer_idx is None and self.is_decoder: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: """Input shape: Batch x Time x Channel""" # if key_value_states are provided this layer is used as a cross-attention layer # for the decoder is_cross_attention = key_value_states is not None - bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self.q_proj(hidden_states) * self.scaling - # get key, value proj - # `past_key_value[0].shape[2] == key_value_states.shape[1]` - # is checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - if ( - is_cross_attention - and past_key_value is not None - and past_key_value[0].shape[2] == key_value_states.shape[1] - ): + query_states = self.q_proj(hidden_states).view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + query_states = query_states * self.scaling + + if past_key_value is not None: + if isinstance(past_key_value, EncoderDecoderCache): + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + curr_past_key_value = past_key_value.cross_attention_cache + else: + curr_past_key_value = past_key_value.self_attention_cache + else: + curr_past_key_value = past_key_value + + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention and past_key_value is not None and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value[0] - value_states = past_key_value[1] - elif is_cross_attention: - # cross_attentions - key_states = self._shape(self.k_proj(key_value_states), -1, bsz) - value_states = self._shape(self.v_proj(key_value_states), -1, bsz) - elif past_key_value is not None: - # reuse k, v, self_attention - key_states = self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - key_states = torch.cat([past_key_value[0], key_states], dim=2) - value_states = torch.cat([past_key_value[1], value_states], dim=2) + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - # self_attention - key_states = 
self._shape(self.k_proj(hidden_states), -1, bsz) - value_states = self._shape(self.v_proj(hidden_states), -1, bsz) - - if self.is_decoder: - # if cross_attention save Tuple(torch.Tensor, torch.Tensor) of all cross attention key/value_states. - # Further calls to cross_attention layer can then reuse all cross-attention - # key/value_states (first "if" case) - # if uni-directional self-attention (decoder) save Tuple(torch.Tensor, torch.Tensor) of - # all previous decoder key/value_states. Further calls to uni-directional self-attention - # can concat previous decoder key/value_states to current projected key/value_states (third "elif" case) - # if encoder bi-directional self-attention `past_key_value` is always `None` - past_key_value = (key_states, value_states) + key_states = self.k_proj(current_states) + value_states = self.v_proj(current_states) + key_states = key_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, -1, self.num_heads, self.head_dim).transpose(1, 2) + + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = curr_past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True proj_shape = (bsz * self.num_heads, -1, self.head_dim) - query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) + query_states = query_states.reshape(*proj_shape) key_states = key_states.reshape(*proj_shape) value_states = value_states.reshape(*proj_shape) @@ -373,10 +376,7 @@ def forward( ) if attention_mask is not None: - if attention_mask.size() != (bsz, 1, tgt_len, src_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" - ) + attention_mask = attention_mask[:, :, :, : key_states.shape[-2]] attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) @@ -425,7 +425,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartEncoderLayer with Bart->TimeSeriesTransformer, BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerEncoderLayer(nn.Module): - def __init__(self, config: TimeSeriesTransformerConfig): + def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -434,6 +434,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, config=config, + layer_idx=layer_idx, ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout @@ -502,7 +503,7 @@ def forward( # Copied from transformers.models.bart.modeling_bart.BartDecoderLayer with Bart->TimeSeriesTransformer, with BART->TIME_SERIES_TRANSFORMER class TimeSeriesTransformerDecoderLayer(nn.Module): - def __init__(self, config: TimeSeriesTransformerConfig): + def __init__(self, config: TimeSeriesTransformerConfig, layer_idx: Optional[int] = None): super().__init__() self.embed_dim = config.d_model @@ -513,6 +514,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): is_decoder=True, is_causal=True, config=config, + 
layer_idx=layer_idx, ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -525,6 +527,7 @@ def __init__(self, config: TimeSeriesTransformerConfig): dropout=config.attention_dropout, is_decoder=True, config=config, + layer_idx=layer_idx, ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) @@ -539,9 +542,10 @@ def forward( encoder_attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, + cache_position: Optional[torch.Tensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -560,47 +564,42 @@ def forward( output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ residual = hidden_states # Self Attention - # decoder uni-directional self-attention cached key/values tuple is at positions 1,2 - self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None - # add present self-attn cache to positions 1,2 of present_key_value tuple - hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states, self_attn_weights, past_key_value = self.self_attn( hidden_states=hidden_states, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.self_attn_layer_norm(hidden_states) # Cross-Attention Block - cross_attn_present_key_value = None cross_attn_weights = None if encoder_hidden_states is not None: residual = hidden_states - # cross_attn cached key/values tuple is at positions 3,4 of present_key_value tuple - cross_attn_past_key_value = past_key_value[-2:] if past_key_value is not None else None - hidden_states, cross_attn_weights, cross_attn_present_key_value = self.encoder_attn( + hidden_states, cross_attn_weights, past_key_value = self.encoder_attn( hidden_states=hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, + past_key_value=past_key_value, output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) hidden_states = residual + hidden_states hidden_states = self.encoder_attn_layer_norm(hidden_states) - # add cross-attn to positions 3,4 of present_key_value tuple - present_key_value = present_key_value + cross_attn_present_key_value - # Fully Connected residual = hidden_states hidden_states = self.activation_fn(self.fc1(hidden_states)) @@ -616,7 +615,7 @@ def forward( outputs += (self_attn_weights, cross_attn_weights) if use_cache: - outputs += 
(present_key_value,) + outputs += (past_key_value,) return outputs @@ -799,7 +798,9 @@ def __init__(self, config: TimeSeriesTransformerConfig): self.embed_positions = TimeSeriesSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model ) - self.layers = nn.ModuleList([TimeSeriesTransformerDecoderLayer(config) for _ in range(config.decoder_layers)]) + self.layers = nn.ModuleList( + [TimeSeriesTransformerDecoderLayer(config, layer_idx=i) for i in range(config.decoder_layers)] + ) self.layernorm_embedding = nn.LayerNorm(config.d_model) self.gradient_checkpointing = False @@ -819,6 +820,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPastAndCrossAttentions]: r""" Args: @@ -876,6 +878,9 @@ def forward( for more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -885,9 +890,22 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict input_shape = inputs_embeds.size()[:-1] + # initialize `past_key_values` + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." 
+ ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) - # past_key_values_length - past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0 + past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0 + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device + ) attention_mask = _prepare_4d_causal_attention_mask( attention_mask, input_shape, inputs_embeds, past_key_values_length @@ -916,7 +934,7 @@ def forward( all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None all_cross_attentions = () if (output_attentions and encoder_hidden_states is not None) else None - next_decoder_cache = () if use_cache else None + next_decoder_cache = None # check if head_mask/cross_attn_head_mask has a correct number of layers specified if desired for attn_mask, mask_name in zip([head_mask, cross_attn_head_mask], ["head_mask", "cross_attn_head_mask"]): @@ -936,8 +954,6 @@ def forward( if dropout_probability < self.layerdrop: continue - past_key_value = past_key_values[idx] if past_key_values is not None else None - if self.gradient_checkpointing and self.training: layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, @@ -961,14 +977,15 @@ def forward( cross_attn_layer_head_mask=( cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None ), - past_key_value=past_key_value, + past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] if use_cache: - next_decoder_cache += (layer_outputs[3 if output_attentions else 1],) + next_decoder_cache = layer_outputs[3 if output_attentions else 1] if output_attentions: all_self_attns += (layer_outputs[1],) @@ -981,6 +998,9 @@ def forward( all_hidden_states += (hidden_states,) next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = past_key_values.to_legacy_cache() + if not return_dict: return tuple( v @@ -1155,6 +1175,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1324,6 +1345,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1406,6 +1428,7 @@ def forward( output_attentions: Optional[bool] = None, use_cache: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Seq2SeqTSModelOutput, Tuple]: r""" past_values (`torch.FloatTensor` of shape `(batch_size, sequence_length)` or `(batch_size, sequence_length, input_size)`): @@ -1577,6 +1600,7 @@ def forward( output_attentions=output_attentions, use_cache=use_cache, return_dict=return_dict, + cache_position=cache_position, ) prediction_loss = None diff --git a/src/transformers/models/trocr/modeling_trocr.py b/src/transformers/models/trocr/modeling_trocr.py index e2c77f577991..152243da0845 100644 --- a/src/transformers/models/trocr/modeling_trocr.py +++ b/src/transformers/models/trocr/modeling_trocr.py @@ -24,7 
+24,10 @@ from ...activations import ACT2FN from ...generation import GenerationMixin -from ...modeling_attn_mask_utils import _prepare_4d_attention_mask, _prepare_4d_causal_attention_mask +from ...modeling_attn_mask_utils import ( + _prepare_4d_attention_mask, + _prepare_4d_causal_attention_mask, +) from ...modeling_outputs import BaseModelOutputWithPastAndCrossAttentions, CausalLMOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel from ...utils import auto_docstring, logging @@ -46,15 +49,18 @@ def __init__(self, num_embeddings: int, embedding_dim: int): self.offset = 2 super().__init__(num_embeddings + self.offset, embedding_dim) - def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0): + def forward(self, input_ids: torch.Tensor, past_key_values_length: int = 0, position_ids: torch.Tensor = None): """`input_ids' shape is expected to be [bsz x seqlen].""" - bsz, seq_len = input_ids.shape[:2] - positions = torch.arange( - past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device - ).expand(bsz, -1) + if position_ids is None: + bsz, seq_len = input_ids.shape[:2] + position_ids = torch.arange( + past_key_values_length, past_key_values_length + seq_len, dtype=torch.long, device=self.weight.device + ).expand(bsz, -1) + else: + position_ids = position_ids.unsqueeze(0) - return super().forward(positions + self.offset) + return super().forward(position_ids + self.offset) # Copied from transformers.models.bart.modeling_bart.BartScaledWordEmbedding with Bart->TrOCR diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 4ad2e9144490..7eb38da9acbe 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -465,7 +465,7 @@ def forward(self, hidden_states): return hidden_states, norm_hidden_states -# Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->Wav2Vec2 +# Copied from transformers.models.hubert.modeling_hubert.HubertAttention with Hubert->Wav2Vec2 class Wav2Vec2Attention(nn.Module): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -624,7 +624,7 @@ def forward( return attn_output, attn_weights_reshaped, past_key_value -# Copied from transformers.models.bart.modeling_bart.BartFlashAttention2 with Bart->Wav2Vec2 +# Copied from transformers.models.hubert.modeling_hubert.HubertFlashAttention2 with Hubert->Wav2Vec2 class Wav2Vec2FlashAttention2(Wav2Vec2Attention): """ Wav2Vec2 flash attention module. 
This module inherits from `Wav2Vec2Attention` as the weights of the module stays @@ -748,7 +748,7 @@ def forward( class Wav2Vec2SdpaAttention(Wav2Vec2Attention): - # Copied from transformers.models.bart.modeling_bart.BartSdpaAttention.forward with Bart->Wav2Vec2 + # Copied from transformers.models.hubert.modeling_hubert.HubertSdpaAttention.forward with Hubert->Wav2Vec2 def forward( self, hidden_states: torch.Tensor, diff --git a/src/transformers/models/whisper/modeling_whisper.py b/src/transformers/models/whisper/modeling_whisper.py index 8d354f4f92cd..2bf387de82a6 100644 --- a/src/transformers/models/whisper/modeling_whisper.py +++ b/src/transformers/models/whisper/modeling_whisper.py @@ -259,10 +259,6 @@ def __init__( self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - # Copied from transformers.models.bart.modeling_bart.BartAttention._shape with BART->whisper - def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): - return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() - def forward( self, hidden_states: torch.Tensor, @@ -281,7 +277,9 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self._shape(self.q_proj(hidden_states) * self.scaling, tgt_len, bsz) + query_states = self.q_proj(hidden_states) * self.scaling + query_states = query_states.view(bsz, tgt_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous() if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) @@ -299,8 +297,10 @@ def forward( key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states), -1, bsz) - value_states = self._shape(self.v_proj(current_states), -1, bsz) + key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None @@ -398,8 +398,10 @@ def forward( key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states), -1, bsz) - value_states = self._shape(self.v_proj(current_states), -1, bsz) + key_states = self.k_proj(current_states).view(bsz, tgt_len, self.num_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, tgt_len, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None @@ -496,7 +498,8 @@ def forward( bsz, tgt_len, _ = hidden_states.size() # get query proj - query_states = self._shape(self.q_proj(hidden_states), tgt_len, bsz) + query_states = self.q_proj(hidden_states).view(bsz, tgt_len, self.num_heads, self.head_dim) + query_states = query_states.transpose(1, 2).contiguous() if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) @@ 
-514,8 +517,10 @@ def forward( key_states = past_key_value.key_cache[self.layer_idx] value_states = past_key_value.value_cache[self.layer_idx] else: - key_states = self._shape(self.k_proj(current_states), -1, bsz) - value_states = self._shape(self.v_proj(current_states), -1, bsz) + key_states = self.k_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + value_states = self.v_proj(current_states).view(bsz, -1, self.num_heads, self.head_dim) + key_states = key_states.transpose(1, 2).contiguous() + value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None diff --git a/src/transformers/models/xglm/modeling_xglm.py b/src/transformers/models/xglm/modeling_xglm.py index 1b9c14ddcaf1..dc1e268cbe4a 100755 --- a/src/transformers/models/xglm/modeling_xglm.py +++ b/src/transformers/models/xglm/modeling_xglm.py @@ -282,7 +282,7 @@ def __init__(self, config: XGLMConfig): self.fc2 = nn.Linear(config.ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) - # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward + # copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer.forward def forward( self, hidden_states: torch.Tensor, diff --git a/tests/generation/test_utils.py b/tests/generation/test_utils.py index aafa003ee670..0bf77c60d204 100644 --- a/tests/generation/test_utils.py +++ b/tests/generation/test_utils.py @@ -2144,6 +2144,7 @@ def test_generate_compile_model_forward(self): compile_config._compile_all_devices = True # force compilation (e.g. fast CI, CPU) generation_kwargs = { + "use_cache": True, "do_sample": False, "max_new_tokens": 5, "return_dict_in_generate": True, diff --git a/tests/models/bart/test_modeling_bart.py b/tests/models/bart/test_modeling_bart.py index 6d80b8413108..34fbcd80f9bf 100644 --- a/tests/models/bart/test_modeling_bart.py +++ b/tests/models/bart/test_modeling_bart.py @@ -87,7 +87,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=50, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -1164,8 +1164,7 @@ def test_cnn_summarization_same_as_fairseq(self): [FRANCE_ARTICLE, SHORTER_ARTICLE, IRAN_ARTICLE, ARTICLE_SUBWAY], max_length=1024, padding="max_length", - truncation_strategy="only_first", - truncation=True, + truncation="only_first", return_tensors="pt", ) @@ -1301,7 +1300,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -1365,6 +1364,7 @@ def prepare_config_and_inputs(self): decoder_start_token_id=self.decoder_start_token_id, max_position_embeddings=self.max_position_embeddings, is_encoder_decoder=self.is_encoder_decoder, + forced_eos_token_id=None, ) return ( @@ -1465,9 +1465,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx 
= ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py index ecdde66786cf..34e633757acd 100644 --- a/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py +++ b/tests/models/bigbird_pegasus/test_modeling_bigbird_pegasus.py @@ -611,7 +611,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -767,7 +767,7 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/biogpt/test_modeling_biogpt.py b/tests/models/biogpt/test_modeling_biogpt.py index e359c1d6d75b..232f7176b23c 100644 --- a/tests/models/biogpt/test_modeling_biogpt.py +++ b/tests/models/biogpt/test_modeling_biogpt.py @@ -320,6 +320,7 @@ def test_batch_generation(self): # Define PAD Token = EOS Token = 50256 tokenizer.pad_token = tokenizer.eos_token model.config.pad_token_id = model.config.eos_token_id + model.generation_config.pad_token_id = model.generation_config.eos_token_id # use different length sentences to test batching sentences = [ @@ -333,10 +334,11 @@ def test_batch_generation(self): outputs = model.generate( input_ids=input_ids, attention_mask=inputs["attention_mask"].to(torch_device), + max_new_tokens=10, ) inputs_non_padded = tokenizer(sentences[0], return_tensors="pt").input_ids.to(torch_device) - output_non_padded = model.generate(input_ids=inputs_non_padded) + output_non_padded = model.generate(input_ids=inputs_non_padded, max_new_tokens=10) num_paddings = inputs_non_padded.shape[-1] - inputs["attention_mask"][-1].long().sum().item() inputs_padded = tokenizer(sentences[1], return_tensors="pt").input_ids.to(torch_device) diff --git a/tests/models/blenderbot/test_modeling_blenderbot.py b/tests/models/blenderbot/test_modeling_blenderbot.py index 945506eeb3ca..83a5c73ff7b7 100644 --- a/tests/models/blenderbot/test_modeling_blenderbot.py +++ b/tests/models/blenderbot/test_modeling_blenderbot.py @@ -356,7 +356,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -500,9 +500,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py index 8bdf116647a6..5e9163376a2c 100644 --- a/tests/models/blenderbot_small/test_modeling_blenderbot_small.py +++ 
b/tests/models/blenderbot_small/test_modeling_blenderbot_small.py @@ -366,7 +366,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -509,9 +509,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values, attention_mask=attn_mask)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, past_key_values=past_key_values, attention_mask=attn_mask, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/informer/test_modeling_informer.py b/tests/models/informer/test_modeling_informer.py index cb197631975e..92fddb499891 100644 --- a/tests/models/informer/test_modeling_informer.py +++ b/tests/models/informer/test_modeling_informer.py @@ -170,8 +170,9 @@ def check_encoder_decoder_model_standalone(self, config, inputs_dict): embed_positions = InformerSinusoidalPositionalEmbedding( config.context_length + config.prediction_length, config.d_model - ).to(torch_device) + ) embed_positions._init_weight() + embed_positions = embed_positions.to(torch_device) self.parent.assertTrue(torch.equal(model.encoder.embed_positions.weight, embed_positions.weight)) self.parent.assertTrue(torch.equal(model.decoder.embed_positions.weight, embed_positions.weight)) diff --git a/tests/models/m2m_100/test_modeling_m2m_100.py b/tests/models/m2m_100/test_modeling_m2m_100.py index 91bda0d8b547..c264de826c20 100644 --- a/tests/models/m2m_100/test_modeling_m2m_100.py +++ b/tests/models/m2m_100/test_modeling_m2m_100.py @@ -82,7 +82,7 @@ def __init__( attention_probs_dropout_prob=0.1, encoder_layerdrop=0.0, decoder_layerdrop=0.0, - max_position_embeddings=20, + max_position_embeddings=50, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -426,7 +426,7 @@ def test_flash_attn_2_seq_to_seq_generation(self): Overwriting the common test as the test is flaky on tiny models """ model = M2M100ForConditionalGeneration.from_pretrained( - "facebook/m2m100_418M", attn_implementation="flash_attention_2" + "facebook/m2m100_418M", torch_dtype=torch.float16, attn_implementation="flash_attention_2" ).to(torch_device) tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="fr", tgt_lang="en") diff --git a/tests/models/marian/test_modeling_marian.py b/tests/models/marian/test_modeling_marian.py index 52f2232ee69a..a53627852f5d 100644 --- a/tests/models/marian/test_modeling_marian.py +++ b/tests/models/marian/test_modeling_marian.py @@ -99,7 +99,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=100, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -653,7 +653,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=100, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -796,9 +796,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - 
"last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/mbart/test_modeling_mbart.py b/tests/models/mbart/test_modeling_mbart.py index 9ec4d56a1a1c..3b328cfe9e01 100644 --- a/tests/models/mbart/test_modeling_mbart.py +++ b/tests/models/mbart/test_modeling_mbart.py @@ -538,7 +538,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -681,9 +681,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/pegasus/test_modeling_pegasus.py b/tests/models/pegasus/test_modeling_pegasus.py index d0b86d5fa5b5..2b7a127d4828 100644 --- a/tests/models/pegasus/test_modeling_pegasus.py +++ b/tests/models/pegasus/test_modeling_pegasus.py @@ -542,9 +542,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/pegasus_x/test_modeling_pegasus_x.py b/tests/models/pegasus_x/test_modeling_pegasus_x.py index a1439ebaabc5..2ffdc4636756 100644 --- a/tests/models/pegasus_x/test_modeling_pegasus_x.py +++ b/tests/models/pegasus_x/test_modeling_pegasus_x.py @@ -78,7 +78,7 @@ def __init__( hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, - max_position_embeddings=20, + max_position_embeddings=50, eos_token_id=2, pad_token_id=1, bos_token_id=0, @@ -676,7 +676,7 @@ def __init__( decoder_layers=2, encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -819,7 +819,7 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids)["last_hidden_state"] - output_from_past = model(next_tokens, past_key_values=past_key_values)["last_hidden_state"] + output_from_past = model(next_tokens, past_key_values=past_key_values, use_cache=True)["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/models/plbart/test_modeling_plbart.py b/tests/models/plbart/test_modeling_plbart.py index 7255f2b7c530..569a63d435e5 100644 --- a/tests/models/plbart/test_modeling_plbart.py +++ b/tests/models/plbart/test_modeling_plbart.py @@ -496,7 +496,7 @@ def __init__( decoder_layers=2, 
encoder_attention_heads=4, decoder_attention_heads=4, - max_position_embeddings=30, + max_position_embeddings=50, is_encoder_decoder=False, pad_token_id=0, bos_token_id=1, @@ -634,9 +634,9 @@ def create_and_check_decoder_model_attention_mask_past( # get two different outputs output_from_no_past = model(next_input_ids, attention_mask=attn_mask)["last_hidden_state"] - output_from_past = model(next_tokens, attention_mask=attn_mask, past_key_values=past_key_values)[ - "last_hidden_state" - ] + output_from_past = model( + next_tokens, attention_mask=attn_mask, past_key_values=past_key_values, use_cache=True + )["last_hidden_state"] # select random slice random_slice_idx = ids_tensor((1,), output_from_past.shape[-1]).item() diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1e5714d36631..338047e95a17 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -735,6 +735,7 @@ def check_determinism(first, second): model = model_class(config) model.to(torch_device) model.eval() + print(model_class) with torch.no_grad(): first = model(**self._prepare_for_class(inputs_dict, model_class))[0] second = model(**self._prepare_for_class(inputs_dict, model_class))[0] @@ -4130,6 +4131,9 @@ def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): if "position_ids" not in inspect.signature(model.forward).parameters: self.skipTest("Model does not support position_ids") + if "position_ids" not in inspect.signature(model.forward).parameters: + continue # this model doesn't accept position ids as input + with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) @@ -4268,7 +4272,16 @@ def test_flash_attn_2_from_config(self): dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device) dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device) - _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + if config.is_encoder_decoder: + _ = fa2_model( + input_ids=dummy_input, + attention_mask=dummy_attention_mask, + decoder_input_ids=dummy_input.clone(), + decoder_attention_mask=dummy_attention_mask.clone(), + ) + else: + _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask) + with tempfile.TemporaryDirectory() as tmpdirname: fa2_model.save_pretrained(tmpdirname) model_from_pretrained = model_class.from_pretrained(tmpdirname) @@ -4327,8 +4340,10 @@ def test_custom_4d_attention_mask(self): set_config_for_less_flaky_test(config) if getattr(config, "sliding_window", 0) is not None and getattr(config, "sliding_window", 0) > 0: self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") - model = model_class(config).to(device=torch_device, dtype=torch.float32) + model = model_class(config).to(device=torch_device, dtype=torch.float32).eval() set_model_for_less_flaky_test(model) + if "position_ids" not in inspect.signature(model.forward).parameters: + continue # model doesn't accept position ids and probably has special way to model positions if "position_ids" not in inspect.signature(model.forward).parameters: continue # this model doesn't accept position ids as input
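
The decoder hunks above wrap a legacy tuple cache in an `EncoderDecoderCache`, derive `cache_position` from the number of already-cached tokens, and convert back to tuples on the way out when the caller passed the legacy format. A minimal sketch of that round trip, using only the cache methods the diff itself calls; the tensor shapes and the four-tensors-per-layer legacy layout are illustrative assumptions:

```python
import torch
from transformers.cache_utils import EncoderDecoderCache

# Assumed legacy layout: one 4-tuple per decoder layer of
# (self_attn_key, self_attn_value, cross_attn_key, cross_attn_value);
# here 2 layers, batch 1, 4 heads, 3 cached target / 6 cached source tokens, head_dim 8.
legacy = tuple(
    (
        torch.zeros(1, 4, 3, 8),
        torch.zeros(1, 4, 3, 8),
        torch.zeros(1, 4, 6, 8),
        torch.zeros(1, 4, 6, 8),
    )
    for _ in range(2)
)

past_key_values = EncoderDecoderCache.from_legacy_cache(legacy)

# get_seq_length() reports how many target tokens are already cached...
past_length = past_key_values.get_seq_length()             # 3

# ...which is what the decoder now uses to build cache_position for the new tokens.
num_new_tokens = 2
cache_position = torch.arange(past_length, past_length + num_new_tokens)
print(cache_position)                                       # tensor([3, 4])

# When the caller handed in tuples, the model converts back before returning.
next_cache = past_key_values.to_legacy_cache()
print(len(next_cache), len(next_cache[0]))                  # 2 layers, 4 tensors each
```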
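
The learned positional embeddings (TrOCR above, and the matching decoder embeddings elsewhere in this diff) now accept an explicit `position_ids` tensor and only fall back to an `arange` starting at `past_key_values_length`; the lookup is still shifted by the historical `offset = 2`. A standalone restatement of that logic, just to show that the two call paths land on the same table row (class name and toy sizes are illustrative, not the library classes):

```python
import torch
import torch.nn as nn

class LearnedPositionalEmbedding(nn.Embedding):
    """Toy version of the pattern: the table is over-allocated by an offset of 2
    and positions default to a range starting at the cached length."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, input_ids, past_key_values_length: int = 0, position_ids=None):
        if position_ids is None:
            bsz, seq_len = input_ids.shape[:2]
            position_ids = torch.arange(
                past_key_values_length, past_key_values_length + seq_len,
                dtype=torch.long, device=self.weight.device,
            ).expand(bsz, -1)
        else:
            position_ids = position_ids.unsqueeze(0)
        return super().forward(position_ids + self.offset)

emb = LearnedPositionalEmbedding(num_embeddings=16, embedding_dim=8)
tokens = torch.ones(2, 1, dtype=torch.long)

# Default path: one new token with 3 tokens already cached -> position 3.
implicit = emb(tokens, past_key_values_length=3)

# Explicit path: the caller hands the positions in directly.
explicit = emb(tokens, position_ids=torch.tensor([3]))

print(implicit.shape, explicit.shape)        # (2, 1, 8) and (1, 1, 8)
print(torch.equal(implicit[:1], explicit))   # True: both read row 3 + offset of the table
```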
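
The Whisper attention hunks drop the small `_shape` helper and inline the projection reshape instead. The reshape itself is the standard split-heads pattern, `(bsz, seq, embed) -> (bsz, num_heads, seq, head_dim)`; a self-contained sketch with toy dimensions:

```python
import torch
import torch.nn as nn

bsz, tgt_len, embed_dim, num_heads = 2, 5, 16, 4
head_dim = embed_dim // num_heads
scaling = head_dim ** -0.5

q_proj = nn.Linear(embed_dim, embed_dim)
k_proj = nn.Linear(embed_dim, embed_dim)
hidden_states = torch.randn(bsz, tgt_len, embed_dim)

# Query: project, scale, split heads, then move the head axis before the sequence axis.
query_states = q_proj(hidden_states) * scaling
query_states = query_states.view(bsz, tgt_len, num_heads, head_dim).transpose(1, 2).contiguous()

# Keys follow the same pattern; -1 lets the source length differ (e.g. cross-attention).
key_states = k_proj(hidden_states).view(bsz, -1, num_heads, head_dim).transpose(1, 2).contiguous()

# Attention scores are then (bsz, heads, query_len, key_len).
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3))
print(query_states.shape, attn_weights.shape)   # (2, 4, 5, 4) and (2, 4, 5, 5)
```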
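
The recurring test change above, passing `use_cache=True` whenever `past_key_values` is given explicitly, exercises the equivalence the `create_and_check_decoder_model_attention_mask_past` helpers are built around: a single full forward pass and a cached two-step pass must agree on the last position. A hedged sketch of that check with a tiny randomly initialized `BartForCausalLM` (the config values are made up for the example, not the test's):

```python
import torch
from transformers import BartConfig, BartForCausalLM

torch.manual_seed(0)
config = BartConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=50, pad_token_id=1, bos_token_id=0, eos_token_id=2,
)
model = BartForCausalLM(config).eval()

input_ids = torch.randint(3, 99, (1, 8))
next_token = torch.randint(3, 99, (1, 1))

with torch.no_grad():
    # One pass over the full sequence, no cache involved.
    full = model(torch.cat([input_ids, next_token], dim=-1)).logits

    # Same thing in two steps: cache the prefix, then feed only the new token.
    past = model(input_ids, use_cache=True).past_key_values
    step = model(next_token, past_key_values=past, use_cache=True).logits

# The logits for the last position should match up to numerics.
print(torch.allclose(full[:, -1], step[:, -1], atol=1e-4))   # True
```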
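
The BioGPT batch-generation test now sets the padding token on `model.generation_config` in addition to `model.config`, and bounds the run with `max_new_tokens`. The same pattern on a tiny random causal LM (the model choice and sizes are stand-ins for the example, not the test's checkpoint):

```python
import torch
from transformers import BartConfig, BartForCausalLM

torch.manual_seed(0)
config = BartConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=50, pad_token_id=1, bos_token_id=0, eos_token_id=2,
)
model = BartForCausalLM(config).eval()

# generate() reads special tokens from generation_config, so the padding token
# has to be set there as well when it is only defined via config/tokenizer.
model.generation_config.pad_token_id = model.generation_config.eos_token_id

# Left-padded batch of two prompts of different lengths.
input_ids = torch.tensor([[1, 1, 5, 6, 7],
                          [8, 9, 10, 11, 12]])
attention_mask = (input_ids != 1).long()

out = model.generate(
    input_ids=input_ids,
    attention_mask=attention_mask,
    max_new_tokens=10,
    do_sample=False,
)
print(out.shape)   # (2, <= 5 + 10), shorter if both rows hit EOS early
```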
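
Finally, the common flash-attention-2 smoke test now branches on `config.is_encoder_decoder` and feeds explicit decoder inputs for seq2seq models. The dispatch looks like the sketch below; it is shown with the default attention implementation so it runs without the flash-attn dependency, whereas the test itself builds the model with `attn_implementation="flash_attention_2"` (sizes are again illustrative):

```python
import torch
from transformers import BartConfig, BartModel

config = BartConfig(
    vocab_size=99, d_model=16, encoder_layers=2, decoder_layers=2,
    encoder_attention_heads=4, decoder_attention_heads=4,
    encoder_ffn_dim=32, decoder_ffn_dim=32,
    max_position_embeddings=50, pad_token_id=1, bos_token_id=0, eos_token_id=2,
)
model = BartModel(config).eval()

dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]])
dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]])

with torch.no_grad():
    if config.is_encoder_decoder:
        # Mirror the updated common test: seq2seq models get explicit decoder
        # inputs so the decoder path (and its attention mask) is exercised too.
        out = model(
            input_ids=dummy_input,
            attention_mask=dummy_attention_mask,
            decoder_input_ids=dummy_input.clone(),
            decoder_attention_mask=dummy_attention_mask.clone(),
        )
    else:
        out = model(input_ids=dummy_input, attention_mask=dummy_attention_mask)

print(out.last_hidden_state.shape)   # (2, 4, 16)
```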