diff --git a/docs/source/en/perf_infer_gpu_one.md b/docs/source/en/perf_infer_gpu_one.md index a81fc4938138..b2deb1c12e04 100644 --- a/docs/source/en/perf_infer_gpu_one.md +++ b/docs/source/en/perf_infer_gpu_one.md @@ -208,6 +208,7 @@ For now, Transformers supports SDPA inference and training for the following arc * [OLMo](https://huggingface.co/docs/transformers/model_doc/olmo#transformers.OlmoModel) * [PaliGemma](https://huggingface.co/docs/transformers/model_doc/paligemma#transformers.PaliGemmaForConditionalGeneration) * [Phi](https://huggingface.co/docs/transformers/model_doc/phi#transformers.PhiModel) +* [Phi3](https://huggingface.co/docs/transformers/model_doc/phi3#transformers.Phi3Model) * [Idefics](https://huggingface.co/docs/transformers/model_doc/idefics#transformers.IdeficsModel) * [Whisper](https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperModel) * [Mistral](https://huggingface.co/docs/transformers/model_doc/mistral#transformers.MistralModel) diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 7cef177fef99..777dce4849d3 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -100,8 +100,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index 50ec015e521c..158fcae42272 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -98,8 +98,6 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, s self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # For BC we register cos and sin cached - self.max_seq_len_cached = max_position_embeddings @torch.no_grad() def forward(self, x, position_ids): diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index a2c3793c0119..3f5b7fefc0f5 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -26,11 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import ( - _prepare_4d_causal_attention_mask, - _prepare_4d_causal_attention_mask_for_sdpa, -) +from ...cache_utils import Cache, DynamicCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -75,88 +72,63 @@ def _get_unpad_data(attention_mask): ) -# Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Phi +# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Phi class PhiRotaryEmbedding(nn.Module): - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + def __init__(self, dim, 
max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): super().__init__() - + self.scaling_factor = scaling_factor self.dim = dim self.max_position_embeddings = max_position_embeddings self.base = base inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) self.register_buffer("inv_freq", inv_freq, persistent=False) - # Build here to make `torch.jit.trace` work. - self._set_cos_sin_cache( - seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() - ) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) - - def forward(self, x, seq_len=None): + @torch.no_grad() + def forward(self, x, position_ids): # x: [bs, num_attention_heads, seq_len, head_size] - if seq_len > self.max_seq_len_cached: - self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) - - return ( - self.cos_cached[:seq_len].to(dtype=x.dtype), - self.sin_cached[:seq_len].to(dtype=x.dtype), - ) - - -# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->Phi + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Phi class PhiLinearScalingRotaryEmbedding(PhiRotaryEmbedding): """PhiRotaryEmbedding extended with linear scaling. 
Credits to the Reddit user /u/kaiokendev""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) - t = t / self.scaling_factor - - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + def forward(self, x, position_ids): + # difference to the original RoPE: a scaling factor is aplied to the position ids + position_ids = position_ids.float() / self.scaling_factor + cos, sin = super().forward(x, position_ids) + return cos, sin -# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->Phi +# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Phi class PhiDynamicNTKScalingRotaryEmbedding(PhiRotaryEmbedding): """PhiRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla""" - def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0): - self.scaling_factor = scaling_factor - super().__init__(dim, max_position_embeddings, base, device) - - def _set_cos_sin_cache(self, seq_len, device, dtype): - self.max_seq_len_cached = seq_len - + def forward(self, x, position_ids): + # difference to the original RoPE: inv_freq is recomputed when the sequence length > original length + seq_len = torch.max(position_ids) + 1 if seq_len > self.max_position_embeddings: base = self.base * ( (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1) ) ** (self.dim / (self.dim - 2)) - inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) - self.register_buffer("inv_freq", inv_freq, persistent=False) - - t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + inv_freq = 1.0 / ( + base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(x.device) / self.dim) + ) + self.register_buffer("inv_freq", inv_freq, persistent=False) # TODO joao: this may break with compilation - freqs = torch.outer(t, self.inv_freq) - # Different from paper, but it uses a different permutation in order to obtain the same calculation - emb = torch.cat((freqs, freqs), dim=-1) - self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) - self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + cos, sin = super().forward(x, position_ids) + return cos, sin # Copied from transformers.models.llama.modeling_llama.rotate_half @@ -167,8 +139,8 @@ def rotate_half(x): return torch.cat((-x2, x1), dim=-1) -# Copied from transformers.models.mixtral.modeling_mixtral.apply_rotary_pos_emb -def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): +# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
Args: @@ -176,9 +148,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): k (`torch.Tensor`): The key tensor. cos (`torch.Tensor`): The cosine part of the rotary embedding. sin (`torch.Tensor`): The sine part of the rotary embedding. - position_ids (`torch.Tensor`): - The position indices of the tokens corresponding to the query and key tensors. For example, this can be - used to pass offsetted position ids when working with a KV-cache. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. unsqueeze_dim (`int`, *optional*, defaults to 1): The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note @@ -189,8 +160,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): Returns: `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. """ - cos = cos[position_ids].unsqueeze(unsqueeze_dim) - sin = sin[position_ids].unsqueeze(unsqueeze_dim) + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) q_embed = (q * cos) + (rotate_half(q) * sin) k_embed = (k * cos) + (rotate_half(k) * sin) return q_embed, k_embed @@ -299,6 +270,7 @@ def _init_rope(self): else: raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + # Adapted from transformers.models.llama.modeling_llama.LlamaAttention.forward def forward( self, hidden_states: torch.Tensor, @@ -307,6 +279,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: bsz, q_len, _ = hidden_states.size() @@ -322,16 +295,8 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + past_key_value = getattr(self, "past_key_value", past_key_value) + cos, sin = self.rotary_emb(value_states, position_ids) # Partial rotary embedding query_rot, query_pass = ( @@ -350,7 +315,8 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) @@ -361,18 +327,9 @@ def forward( query_states.to(torch.float32), key_states.to(torch.float32).transpose(2, 3) ) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) @@ -413,6 +370,7 @@ def __init__(self, *args, **kwargs): # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + # Adapted from transformers.models.llama.modeling_llama.LlamaFlashAttention2.forward def forward( self, hidden_states: torch.Tensor, @@ -421,9 +379,13 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, - **kwargs, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - # PhiFlashAttention2 attention does not support output_attentions + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False @@ -444,29 +406,14 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) - # Partial rotary embedding - query_rot, query_pass = ( - query_states[..., : self.rotary_emb.dim], - query_states[..., self.rotary_emb.dim :], - ) - key_rot, key_pass = ( - key_states[..., : self.rotary_emb.dim], - key_states[..., self.rotary_emb.dim :], - ) - # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor] - query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) - - # [batch_size, seq_length, num_heads, head_dim] - query_states = torch.cat((query_rot, query_pass), dim=-1) - key_states = torch.cat((key_rot, key_pass), dim=-1) + past_key_value = getattr(self, "past_key_value", past_key_value) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache @@ -634,6 +581,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
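
For context, here is a minimal standalone sketch (not part of the patch) of the calling convention the hunks above switch to: `cos`/`sin` are now derived directly from `position_ids` instead of a cached `seq_len` table. The sizes and names below are illustrative only.

```python
import torch

dim, base = 32, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))  # [dim/2]

position_ids = torch.arange(6).unsqueeze(0)                                            # [bs=1, seq_len=6]
inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # [bs, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float()                               # [bs, 1, seq_len]

freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)                    # [bs, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                                                # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([1, 6, 32]) torch.Size([1, 6, 32])
```
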
@@ -650,6 +598,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) bsz, q_len, _ = hidden_states.size() @@ -666,16 +615,7 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." - ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + cos, sin = self.rotary_emb(value_states, position_ids) # Partial rotary embedding query_rot, query_pass = ( @@ -693,36 +633,44 @@ def forward( query_states = torch.cat((query_rot, query_pass), dim=-1) key_states = torch.cat((key_rot, key_pass), dim=-1) + # In case static cache is used, it is an instance attribute. + past_key_value = getattr(self, "past_key_value", past_key_value) + if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim} + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + # SDPA with memory-efficient backend is broken in torch==2.1.2 when using non-contiguous inputs and a custom # attn_mask, so we need to call `.contiguous()` here. This was fixed in torch==2.2.0. # Reference: https://github.com/pytorch/pytorch/issues/112577 - if self.require_contiguous_qkv and query_states.device.type == "cuda" and attention_mask is not None: + if self.require_contiguous_qkv and query_states.device.type == "cuda" and causal_mask is not None: query_states = query_states.contiguous() key_states = key_states.contiguous() value_states = value_states.contiguous() # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. 
- is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = attn_output.view(bsz, q_len, self.hidden_size) attn_output = self.dense(attn_output) @@ -752,6 +700,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -783,6 +732,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) attn_outputs = self.resid_dropout(attn_outputs) @@ -829,6 +779,7 @@ class PhiPreTrainedModel(PreTrainedModel): _supports_flash_attn_2 = True _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True def _init_weights(self, module): std = self.config.initializer_range @@ -961,6 +912,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -970,59 +922,36 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
) - position_ids = position_ids.unsqueeze(0) + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - inputs_embeds = self.embed_dropout(inputs_embeds) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) - # Attention mask. - if self._use_flash_attention_2: - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - elif self._use_sdpa and not output_attentions: - attention_mask = _prepare_4d_causal_attention_mask_for_sdpa( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - ) - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) hidden_states = inputs_embeds @@ -1039,19 +968,22 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, - past_key_values, output_attentions, + use_cache, + past_key_values, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, - past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + past_key_value=past_key_values, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1068,9 +1000,10 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1080,6 +1013,87 @@ def forward( attentions=all_self_attns, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_length() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + class PhiForCausalLM(PhiPreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1132,6 +1146,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1176,6 +1191,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1207,21 +1223,31 @@ def forward( attentions=outputs.attentions, ) - # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, ): past_length = 0 if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1250,13 +1276,23 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
+ model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index e14785bd1f8b..be7a55cc3254 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -26,8 +26,8 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import Cache, DynamicCache -from ...modeling_attn_mask_utils import _prepare_4d_causal_attention_mask +from ...cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutputWithPast, CausalLMOutputWithPast, @@ -41,6 +41,7 @@ add_start_docstrings_to_model_forward, is_flash_attn_2_available, is_flash_attn_greater_or_equal_2_10, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -331,8 +332,12 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + if not is_torchdynamo_compiling: + logger.warning_once( + "You are not running the flash-attention implementation, expect numerical differences." + ) bsz, q_len, _ = hidden_states.size() @@ -346,21 +351,11 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - if self.layer_idx is None: - raise ValueError( - f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " - "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " - "with a layer index." 
- ) - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -369,18 +364,9 @@ def forward( attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) - if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): - raise ValueError( - f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" - f" {attn_weights.size()}" - ) - - if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) - attn_weights = attn_weights + attention_mask + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask # upcast attention to fp32 attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) @@ -429,14 +415,16 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # Phi3FlashAttention2 attention does not support output_attentions if not _flash_supports_window_size: - logger.warning_once( - "The current flash attention version does not support sliding window attention. Please use `attn_implementation='eager'` or upgrade flash-attn library." + raise ValueError( + "The current flash attention version does not support sliding window attention. " + "Please use `attn_implementation='eager'` or upgrade flash-attn library." ) - raise ValueError("The current flash attention version does not support sliding window attention.") output_attentions = False @@ -465,11 +453,8 @@ def forward( ) kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - # Because the input can be padded, the absolute sequence length depends on the max position id. 
- rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=rotary_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) use_sliding_windows = ( _flash_supports_window_size @@ -477,7 +462,8 @@ def forward( and kv_seq_len > self.config.sliding_window ) - if past_key_value is not None: + # We cannot do sliding window with StaticCache bacause last tokens in cache are usually placeholders filled with zeros + if past_key_value is not None and not isinstance(past_key_value, StaticCache): # Activate slicing cache only if the config has a value `sliding_windows` attribute cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 if ( @@ -500,10 +486,10 @@ def forward( ) if attention_mask is not None: - attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = attention_mask[:, :, :, slicing_tokens:] attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -527,11 +513,12 @@ def forward( else: target_dtype = self.qkv_proj.weight.dtype - logger.warning_once( - f"The input hidden states seems to be silently casted in float32, this might be related to" - f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" - f" {target_dtype}." - ) + if not is_torchdynamo_compiling: + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) query_states = query_states.to(target_dtype) key_states = key_states.to(target_dtype) @@ -706,7 +693,6 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query # copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 -# TODO @Arthur no longer copied from LLama after static cache class Phi3SdpaAttention(Phi3Attention): """ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from @@ -723,6 +709,7 @@ def forward( past_key_value: Optional[Cache] = None, output_attentions: bool = False, use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: if output_attentions: # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. 
@@ -751,25 +738,19 @@ def forward( key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) - kv_seq_len = key_states.shape[-2] - if past_key_value is not None: - kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) - cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) - - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) key_states = repeat_kv(key_states, self.num_key_value_groups) value_states = repeat_kv(value_states, self.num_key_value_groups) + causal_mask = attention_mask if attention_mask is not None: - if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): - raise ValueError( - f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" - ) + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, # Reference: https://github.com/pytorch/pytorch/issues/112577. @@ -781,13 +762,13 @@ def forward( # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
- is_causal = True if self.is_causal and attention_mask is None and q_len > 1 else False + is_causal = True if self.is_causal and causal_mask is None and q_len > 1 else False attn_output = torch.nn.functional.scaled_dot_product_attention( query_states, key_states, value_states, - attn_mask=attention_mask, + attn_mask=causal_mask, dropout_p=self.attention_dropout if self.training else 0.0, is_causal=is_causal, ) @@ -829,6 +810,8 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: @@ -860,6 +843,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = residual + self.resid_attn_dropout(attn_outputs) @@ -908,8 +892,9 @@ class Phi3PreTrainedModel(PreTrainedModel): _no_split_modules = ["Phi3DecoderLayer"] _skip_keys_device_placement = "past_key_values" _supports_flash_attn_2 = True - _supports_sdpa = False + _supports_sdpa = True _supports_cache_class = True + _supports_static_cache = True _version = "0.0.5" @@ -1042,6 +1027,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -1051,64 +1037,36 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - # retrieve input_ids and inputs_embeds - if input_ids is not None and inputs_embeds is not None: - raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: - batch_size, seq_length = input_ids.shape[:2] - elif inputs_embeds is not None: - batch_size, seq_length = inputs_embeds.shape[:2] - else: - raise ValueError("You have to specify either input_ids or inputs_embeds") - - past_key_values_length = 0 - - if self.gradient_checkpointing and self.training: - if use_cache: - logger.warning_once( - "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." - ) - use_cache = False - - if use_cache: - use_legacy_cache = not isinstance(past_key_values, Cache) - if use_legacy_cache: - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_key_values_length = past_key_values.get_usable_length(seq_length) + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError( + "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" + ) - if position_ids is None: - device = input_ids.device if input_ids is not None else inputs_embeds.device - position_ids = torch.arange( - past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + if self.gradient_checkpointing and self.training and use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." 
) - position_ids = position_ids.unsqueeze(0).view(-1, seq_length) - else: - position_ids = position_ids.view(-1, seq_length).long() + use_cache = False if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - if attention_mask is not None and self._attn_implementation == "flash_attention_2" and use_cache: - is_padding_right = attention_mask[:, -1].sum().item() != batch_size - if is_padding_right: - raise ValueError( - "You are attempting to perform batched generation with padding_side='right'" - " this may lead to unexpected behaviour for Flash Attention version of Phi3. Make sure to " - " call `tokenizer.padding_side = 'left'` before tokenizing the input. " - ) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = True + past_key_values = DynamicCache.from_legacy_cache(past_key_values) - if self._attn_implementation == "flash_attention_2": - # 2d mask is passed through the layers - attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None - else: - # 4d mask is passed through the layers - attention_mask = _prepare_4d_causal_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) hidden_states = inputs_embeds @@ -1125,20 +1083,22 @@ def forward( layer_outputs = self._gradient_checkpointing_func( decoder_layer.__call__, hidden_states, - attention_mask, + causal_mask, position_ids, past_key_values, output_attentions, use_cache, + cache_position, ) else: layer_outputs = decoder_layer( hidden_states, - attention_mask=attention_mask, + attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, output_attentions=output_attentions, use_cache=use_cache, + cache_position=cache_position, ) hidden_states = layer_outputs[0] @@ -1155,9 +1115,9 @@ def forward( if output_hidden_states: all_hidden_states += (hidden_states,) - next_cache = None - if use_cache: - next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) return BaseModelOutputWithPast( @@ -1167,6 +1127,102 @@ def forward( attentions=all_self_attns, ) + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. 
A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache + if using_sliding_window_cache: + target_length = max(sequence_length, self.config.sliding_window) + # StaticCache + elif using_static_cache: + target_length = past_key_values.get_max_length() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + if attention_mask is not None and attention_mask.dim() == 4: + # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing + if attention_mask.max() != 0: + raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`") + causal_mask = attention_mask + else: + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if self.config.sliding_window is not None: + if not using_sliding_window_cache or sequence_length > self.config.sliding_window: + exclude_mask |= torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - self.config.sliding_window + ) + causal_mask *= exclude_mask + causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.dim() == 2: + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + class Phi3ForCausalLM(Phi3PreTrainedModel): _tied_weights_keys = ["lm_head.weight"] @@ -1220,6 +1276,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple, CausalLMOutputWithPast]: r""" Args: @@ -1264,6 +1321,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, + cache_position=cache_position, ) hidden_states = outputs[0] @@ -1295,21 +1353,31 @@ def forward( attentions=outputs.attentions, ) - # Copied from transformers.models.persimmon.modeling_persimmon.PersimmonForCausalLM.prepare_inputs_for_generation + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( - self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, ): past_length = 0 if past_key_values is not None: # Past key values are always initialized with a `Cache` object -> no need for if-else anymore - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1338,13 +1406,23 @@ def prepare_inputs_for_generation( if inputs_embeds is not None and past_length == 0: model_inputs = {"inputs_embeds": inputs_embeds} else: - model_inputs = {"input_ids": input_ids} + # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise + # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114 + # TODO: use `next_tokens` directly instead. 
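
As a small illustration (toy sizes, with `sliding_window=3` chosen arbitrarily; not part of the patch), this is the core of the mask that `_update_causal_mask` above assembles: future positions and positions more than a window behind the current one are filled with the dtype minimum, everything else stays 0.

```python
import torch

sequence_length, target_length, sliding_window = 5, 5, 3
min_dtype = torch.finfo(torch.float32).min
cache_position = torch.arange(sequence_length)

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype)
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)                        # future tokens
exclude_mask |= torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)   # beyond the window
causal_mask *= exclude_mask
print((causal_mask == 0).int())  # 1 marks positions that may be attended to
```
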
+ model_inputs = {"input_ids": input_ids.contiguous()} + + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] model_inputs.update( { "position_ids": position_ids, + "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/tests/models/phi/test_modeling_phi.py b/tests/models/phi/test_modeling_phi.py index f395b70c1ee2..9735b544f29e 100644 --- a/tests/models/phi/test_modeling_phi.py +++ b/tests/models/phi/test_modeling_phi.py @@ -18,6 +18,7 @@ import unittest import pytest +from packaging import version from parameterized import parameterized from transformers import PhiConfig, is_torch_available, set_seed @@ -397,7 +398,7 @@ def test_model_rope_scaling_from_config(self, scaling_type): # The output should be different for long inputs self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5)) - # Copied from tests.models.falcon.test_modeling_falcon.FalconModelTest.test_model_rope_scaling with Falcon->Phi + # Copied from tests.models.llama.test_modeling_llama.LlamaModelTest.test_model_rope_scaling with Llama->Phi def test_model_rope_scaling(self): config, _ = self.model_tester.prepare_config_and_inputs_for_common() hidden_size = config.hidden_size @@ -409,6 +410,10 @@ def test_model_rope_scaling(self): # Inputs x = torch.randn(1, dtype=torch.float32, device=torch_device) # used exlusively to get the dtype and the device + position_ids_short = torch.arange(short_input_length, dtype=torch.long, device=torch_device) + position_ids_short = position_ids_short.unsqueeze(0) + position_ids_long = torch.arange(long_input_length, dtype=torch.long, device=torch_device) + position_ids_long = position_ids_long.unsqueeze(0) # Sanity check original RoPE original_rope = PhiRotaryEmbedding( @@ -416,10 +421,10 @@ def test_model_rope_scaling(self): max_position_embeddings=config.max_position_embeddings, base=config.rope_theta, ).to(torch_device) - original_cos_short, original_sin_short = original_rope(x, short_input_length) - original_cos_long, original_sin_long = original_rope(x, long_input_length) - torch.testing.assert_close(original_cos_short, original_cos_long[:short_input_length, :]) - torch.testing.assert_close(original_sin_short, original_sin_long[:short_input_length, :]) + original_cos_short, original_sin_short = original_rope(x, position_ids_short) + original_cos_long, original_sin_long = original_rope(x, position_ids_long) + torch.testing.assert_close(original_cos_short, original_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(original_sin_short, original_sin_long[:, :short_input_length, :]) # Sanity check linear RoPE scaling # New position "x" should match original position with index "x/scaling_factor" @@ -429,14 +434,14 @@ def test_model_rope_scaling(self): base=config.rope_theta, scaling_factor=scaling_factor, ).to(torch_device) - linear_cos_short, linear_sin_short = linear_scaling_rope(x, short_input_length) - linear_cos_long, linear_sin_long = linear_scaling_rope(x, long_input_length) - torch.testing.assert_close(linear_cos_short, linear_cos_long[:short_input_length, :]) - torch.testing.assert_close(linear_sin_short, linear_sin_long[:short_input_length, :]) + linear_cos_short, linear_sin_short 
= linear_scaling_rope(x, position_ids_short) + linear_cos_long, linear_sin_long = linear_scaling_rope(x, position_ids_long) + torch.testing.assert_close(linear_cos_short, linear_cos_long[:, :short_input_length, :]) + torch.testing.assert_close(linear_sin_short, linear_sin_long[:, :short_input_length, :]) for new_position in range(0, long_input_length, scaling_factor): original_position = int(new_position // scaling_factor) - torch.testing.assert_close(linear_cos_long[new_position, :], original_cos_long[original_position, :]) - torch.testing.assert_close(linear_sin_long[new_position, :], original_sin_long[original_position, :]) + torch.testing.assert_close(linear_cos_long[:, new_position, :], original_cos_long[:, original_position, :]) + torch.testing.assert_close(linear_sin_long[:, new_position, :], original_sin_long[:, original_position, :]) # Sanity check Dynamic NTK RoPE scaling # Scaling should only be observed after a long input is fed. We can observe that the frequencies increase @@ -447,8 +452,8 @@ def test_model_rope_scaling(self): base=config.rope_theta, scaling_factor=scaling_factor, ).to(torch_device) - ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, short_input_length) - ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, long_input_length) + ntk_cos_short, ntk_sin_short = ntk_scaling_rope(x, position_ids_short) + ntk_cos_long, ntk_sin_long = ntk_scaling_rope(x, position_ids_long) torch.testing.assert_close(ntk_cos_short, original_cos_short) torch.testing.assert_close(ntk_sin_short, original_sin_short) with self.assertRaises(AssertionError): @@ -498,6 +503,16 @@ def test_flash_attn_2_generate_padding_right(self): @slow @require_torch class PhiIntegrationTest(unittest.TestCase): + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmethod + def setUpClass(cls): + if is_torch_available() and torch.cuda.is_available(): + # 8 is for A100 / A10 and 7 for T4 + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + def test_model_phi_1_logits(self): input_ids = { "input_ids": torch.tensor( @@ -564,3 +579,53 @@ def test_phi_2_generation(self): ] self.assertListEqual(output_text, EXPECTED_OUTPUT) + + @slow + @require_torch_gpu + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = { + 8: [ + "Simply put, the theory of relativity states that \n\n$$\nE = mc^2\n$$\n\nwhere $E$ is the energy of an object, $m$ is its mass, and $c$ is the speed of light", + "My favorite all time favorite condiment is ketchup. I love it on everything. 
I love it on my fries, I love it on my burgers, I love it on my hot dogs, I love it on my chicken nuggets, I love it", + ], + 7: [ + "Simply put, the theory of relativity states that \n\n$$\nE = mc^2\n$$\n\nwhere $E$ is the energy of an object, $m$ is its mass, and $c$ is the speed of light", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my fries, I love it on my burgers, I love it on my hot dogs, I love it on my chicken nuggets, I love it", + ], + } + + prompts = [ + "Simply put, the theory of relativity states that ", + "My favorite all time favorite condiment is ketchup.", + ] + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", pad_token="<|endoftext|>", padding_side="left") + model = PhiForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float16).to(torch_device) + inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) + + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) diff --git a/tests/models/phi3/test_modeling_phi3.py b/tests/models/phi3/test_modeling_phi3.py index ad9c4c46aa93..6d04d4fb73f0 100644 --- a/tests/models/phi3/test_modeling_phi3.py +++ b/tests/models/phi3/test_modeling_phi3.py @@ -22,6 +22,7 @@ from transformers import Phi3Config, is_torch_available, set_seed from transformers.testing_utils import ( require_torch, + require_torch_gpu, slow, torch_device, ) @@ -396,6 +397,16 @@ def test_model_rope_scaling_from_config(self, scaling_type): @slow @require_torch class Phi3IntegrationTest(unittest.TestCase): + # This variable is used to determine which CUDA device are we using for our runners (A10 or T4) + # Depending on the hardware we get different logits / generations + cuda_compute_capability_major_version = None + + @classmethod + def setUpClass(cls): + if is_torch_available() and torch.cuda.is_available(): + # 8 is for A100 / A10 and 7 for T4 + cls.cuda_compute_capability_major_version = torch.cuda.get_device_capability()[0] + def test_model_phi3_mini_4k_instruct_logits(self): input_ids = { "input_ids": torch.tensor( @@ -471,3 +482,57 @@ def test_phi3_mini_128k_instruct_generation(self): ] self.assertListEqual(output_text, EXPECTED_OUTPUT) + + @slow + @require_torch_gpu + def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. 
See https://github.com/pytorch/pytorch/issues/121943 + # if version.parse(torch.__version__) < version.parse("2.3.0"): + # self.skipTest("This test requires torch >= 2.3 to run.") + + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. + EXPECTED_TEXT_COMPLETION = { + 8: [ + "Can you provide ways to eat combinations of bananas and dragonfruits? Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for eating combinations of bananas and dragonfruits:\n\n1", + ], + 7: [ + "Can you provide ways to eat combinations of bananas and dragonfruits? Certainly! Bananas and dragonfruits can be combined in various delicious ways. Here are some ideas for eating combinations of bananas and dragonfruits:\n\n1", + ], + } + tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") + model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct", torch_dtype=torch.float16).to( + torch_device + ) + messages = [ + { + "role": "system", + "content": "You are a helpful digital assistant. Please provide safe, ethical and accurate information to the user.", + }, + {"role": "user", "content": "Can you provide ways to eat combinations of bananas and dragonfruits?"}, + ] + inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to( + torch_device + ) + + # Dynamic Cache + generated_ids = model.generate(inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output + + # Static Cache + generated_ids = model.generate( + inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f81bd5c8c38c..bbc141404fcb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -4414,7 +4414,7 @@ def test_custom_4d_attention_mask(self): if not model_class._supports_static_cache: self.skipTest(f"{model_class.__name__} is not guaranteed to work with custom 4D attention masks") config, _ = self.model_tester.prepare_config_and_inputs_for_common() - if getattr(config, "sliding_window", 0) > 0: + if getattr(config, "sliding_window", None) is not None: self.skipTest(f"{model_class.__name__} with sliding window attention is not supported by this test") model = model_class(config).to(device=torch_device, dtype=torch.float32)
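The hunks above migrate Phi and Phi-3 onto the same cache and RoPE interfaces already used by Llama: `prepare_inputs_for_generation` now tracks `cache_position`, the rotary-embedding modules take `position_ids` (shape `[batch, seq_len]`) and return batched `cos`/`sin` tensors (which is why the test slices change from `[:seq_len, :]` to `[:, :seq_len, :]`), and new integration tests exercise static-cache generation under `torch.compile`. As a minimal sketch of the workflow those tests cover — not part of the patch itself — the snippet below runs Phi-2 with a static cache and a compiled forward pass; it assumes a CUDA device and, per the skip condition in the Phi test, torch >= 2.3.

```python
import torch
from transformers import AutoTokenizer, PhiForCausalLM

# Mirrors PhiIntegrationTest.test_compile_static_cache: left padding plus an explicit pad token.
tokenizer = AutoTokenizer.from_pretrained(
    "microsoft/phi-2", pad_token="<|endoftext|>", padding_side="left"
)
model = PhiForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype=torch.float16).to("cuda")

prompts = [
    "Simply put, the theory of relativity states that ",
    "My favorite all time favorite condiment is ketchup.",
]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

# The static cache pre-allocates fixed-shape key/value tensors, which is what lets
# `torch.compile(..., fullgraph=True)` trace the forward pass without shape-driven recompilation.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
    **inputs, max_new_tokens=40, do_sample=False, cache_implementation="static"
)
print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))
```

The expected completions in the new tests are keyed by CUDA compute capability (8 for A100/A10, 7 for T4) because half-precision numerics can differ slightly across GPU architectures even under greedy decoding, which is also why `setUpClass` records `torch.cuda.get_device_capability()[0]`.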