diff --git a/docs/source/en/internal/generation_utils.md b/docs/source/en/internal/generation_utils.md index 7270af049c32..e6872efe7308 100644 --- a/docs/source/en/internal/generation_utils.md +++ b/docs/source/en/internal/generation_utils.md @@ -362,3 +362,4 @@ A [`Constraint`] can be used to force the generation to include specific tokens [[autodoc]] StaticCache - update - get_seq_length + - reorder_cache diff --git a/docs/source/en/llm_optims.md b/docs/source/en/llm_optims.md index f1dc6d5f23ce..4b44c1d78c81 100644 --- a/docs/source/en/llm_optims.md +++ b/docs/source/en/llm_optims.md @@ -65,13 +65,12 @@ tokenizer.batch_decode(outputs, skip_special_tokens=True) ['The theory of special relativity states 1. The speed of light is constant in all inertial reference'] ``` - - +Under the hood, `generate` will attempt to reuse the same cache object, removing the need for re-compilation at each call. However, if the batch size or the maximum output length increases between calls, the cache will have to be reinitialized, triggering a new compilation. -> [!WARNING] -> The `_setup_cache` method is an internal and private method that is still under development. This means it may not be backward compatible and the API design may change in the future. + + -The `_setup_cache` method doesn't support [`~GenerationMixin.generate`] yet, so this method is a bit more involved. You'll need to write your own function to decode the next token given the current token and position and cache position of previously generated tokens. +A [`StaticCache`] object can be passed to the model's forward pass under the `past_key_values` argument, enabling the use of this object as a static kv-cache. Using this strategy, you can write your own function to decode the next token given the current token and position, and the cache position of previously generated tokens. You can also pass the [`StaticCache`] object to [`~GenerationMixin.generate`] and use it across calls, like you would do with a dynamic cache. ```py from transformers import LlamaTokenizer, LlamaForCausalLM, StaticCache, logging @@ -90,17 +89,22 @@ tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", pad_token model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", device_map="sequential") inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) -def decode_one_tokens(model, cur_token, input_pos, cache_position): +def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True )[0] new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] return new_token ``` -There are a few important things you must do to enable static kv-cache and torch.compile with the `_setup_cache` method: +There are a few important things you must do to enable static kv-cache and torch.compile with the [`StaticCache`] class: -1. Access the model's `_setup_cache` method and pass it the [`StaticCache`] class. This is a more flexible method because it allows you to configure parameters like the maximum batch size and sequence length. +1. Initialize a [`StaticCache`] instance before using the model for inference. This is where you can configure parameters such as the maximum batch size and sequence length. 2. 
Call torch.compile on the model to compile the forward pass with the static kv-cache. @@ -109,24 +113,28 @@ There are a few important things you must do to enable static kv-cache and torch ```py batch_size, seq_length = inputs["input_ids"].shape with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 + past_key_values = StaticCache( + config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype + ) + cache_position = torch.arange(seq_length, device=torch_device) + generated_ids = torch.zeros( + batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device + ) + generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) + + logits = model( + **inputs, cache_position=cache_position, past_key_values=past_key_values,return_dict=False, use_cache=True + )[0] + next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] + generated_ids[:, seq_length] = next_token[:, 0] + + decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) + cache_position = torch.tensor([seq_length + 1], device=torch_device) + for _ in range(1, NUM_TOKENS_TO_GENERATE): + with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): + next_token = decode_one_tokens(model, next_token.clone(), None, cache_position, past_key_values) + generated_ids[:, cache_position] = next_token.int() + cache_position += 1 text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) text @@ -134,6 +142,9 @@ text 'My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p'] ``` +> [!TIP] +> If you want to reuse the [`StaticCache`] object on a new prompt, be sure to reset its contents with the `.reset()` method + diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 2ed663b26256..ceca9d3eeb35 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -44,6 +44,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. 
A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.") def get_max_length(self) -> Optional[int]: @@ -61,6 +62,14 @@ def get_usable_length(self, new_seq_length: int, layer_idx: Optional[int] = 0) - return max_length - new_seq_length return previous_seq_length + def reorder_cache(self, beam_idx: torch.LongTensor): + """Reorders the cache for beam search, given the selected beam indices.""" + for layer_idx in range(len(self.key_cache)): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + @property def seen_tokens(self): logger.warning_once( @@ -150,6 +159,7 @@ def update( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` if len(self.key_cache) <= layer_idx: return 0 return self.key_cache[layer_idx].shape[-2] @@ -158,14 +168,6 @@ def get_max_length(self) -> Optional[int]: """Returns the maximum sequence length of the cached states. DynamicCache does not have a maximum length.""" return None - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: """Converts the `DynamicCache` instance into the its equivalent in the legacy cache format.""" legacy_cache = () @@ -244,6 +246,7 @@ def _get_rerotation_cos_sin( def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" + # TODO: deprecate this function in favor of `cache_position` # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length if len(self.key_cache) <= layer_idx: return 0 @@ -332,14 +335,6 @@ def update( return self.key_cache[layer_idx], self.value_cache[layer_idx] - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - class StaticCache(Cache): """ @@ -347,8 +342,7 @@ class StaticCache(Cache): Parameters: config (`PretrainedConfig): - The configuration file defining the `max_position_embeddings`, `hidden_size` and `num_attention_heads` - required to initialize the static cache. + The configuration file defining the shape-related attributes required to initialize the static cache. max_batch_size (`int`): The maximum batch size with which the model will be used. 
max_cache_len (`int`): @@ -373,9 +367,18 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads ) + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim) - self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) - self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device) + for _ in range(config.num_hidden_layers): + # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph + # breaks when updating the cache. + new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device) + torch._dynamo.mark_static_address(new_layer_key_cache) + torch._dynamo.mark_static_address(new_layer_value_cache) + self.key_cache.append(new_layer_key_cache) + self.value_cache.append(new_layer_value_cache) def update( self, @@ -394,42 +397,37 @@ def update( value_states (`torch.Tensor`): The new value states to cache. layer_idx (`int`): - The index of the layer to cache the states for. Kept for backward compatibility + The index of the layer to cache the states for. cache_kwargs (`Dict[str, Any]`, `optional`): - Additional arguments for the cache subclass. The `StaticCache` just needs the `q_len` - to know how much of the cache it should overwrite. + Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input + to know where to write in the cache. Return: A tuple containing the updated key and value states. """ - new_cache_positions = cache_kwargs.get("cache_position") - k_out = self.key_cache - v_out = self.value_cache + cache_position = cache_kwargs.get("cache_position") + k_out = self.key_cache[layer_idx] + v_out = self.value_cache[layer_idx] - k_out[:, :, new_cache_positions] = key_states - v_out[:, :, new_cache_positions] = value_states + k_out[:, :, cache_position] = key_states + v_out[:, :, cache_position] = value_states return k_out, v_out def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: - """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC""" + """Returns the sequence length of the cached states that were seen by the model.""" # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's # limit the check to the first batch member and head dimension. - # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after - # https://github.com/pytorch/pytorch/issues/120248 is fixed - return (self.key_cache[0, 0].any(dim=-1)).sum() + # TODO: deprecate this function in favor of `cache_position` + return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum() def get_max_length(self) -> Optional[int]: - """Returns the maximum sequence length of the cached states. 
DynamicCache does not have a maximum length.""" + """Returns the maximum sequence length of the cached states.""" return self.max_cache_len - def reorder_cache(self, beam_idx: torch.LongTensor): - """Reorders the cache for beam search, given the selected beam indices.""" - device = self.key_cache.device - self.key_cache = self.key_cache.index_select(0, beam_idx.to(device)) - device = self.value_cache.device - self.value_cache = self.value_cache.index_select(0, beam_idx.to(device)) - - def to_legacy_cache(self): - """Dummy function for BC. We have to keep it because otherwise the call in the forward of models will break it""" - return None + def reset(self): + """Resets the cache values while preserving the objects.""" + for layer_idx in range(len(self.key_cache)): + # In-place ops prevent breaking the static address + self.key_cache[layer_idx].zero_() + self.value_cache[layer_idx].zero_() diff --git a/src/transformers/generation/utils.py b/src/transformers/generation/utils.py index 9e6a58d3e5a5..1633e41021ae 100644 --- a/src/transformers/generation/utils.py +++ b/src/transformers/generation/utils.py @@ -1310,6 +1310,34 @@ def _get_initial_cache_position(self, input_ids, model_kwargs): model_kwargs["cache_position"] = torch.arange(past_length, cur_len, device=input_ids.device) return model_kwargs + def _get_static_cache(self, max_batch_size: int, max_cache_len: int) -> StaticCache: + """ + Sets a static cache for `generate` that will persist across calls. A new cache will only be initialized if a + new `generate` call requires a larger cache. + + Returns the resulting static cache object. + """ + needs_new_cache = ( + not hasattr(self, "_static_cache") + or self._static_cache.max_batch_size < max_batch_size + or self._static_cache.max_cache_len < max_cache_len + ) + if needs_new_cache: + if hasattr(self.config, "_pre_quantization_dtype"): + cache_dtype = self.config._pre_quantization_dtype + else: + cache_dtype = self.dtype + self._static_cache = StaticCache( + config=self.config, + max_batch_size=max_batch_size, + max_cache_len=max_cache_len, + device=self.device, + dtype=cache_dtype, + ) + else: + self._static_cache.reset() # reset the cache for a new generation + return self._static_cache + @torch.no_grad() def generate( self, @@ -1514,19 +1542,19 @@ def generate( input_ids_length=input_ids_length, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if generation_config.cache_implementation is not None and model_kwargs.get("past_key_values") is not None: + raise ValueError( + "Passing both `cache_implementation` (used to initialize certain caches) and `past_key_values` (a " + "Cache object) is unsupported. Please use only one of the two." + ) + elif generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: + if not self._supports_cache_class: + raise ValueError( + "This model does not support the `cache_implementation` argument. Please check the following " + "issue: https://github.com/huggingface/transformers/issues/28981." + ) if generation_config.cache_implementation == "static": - if model_kwargs.get("past_key_values", False) is not False: - raise ValueError( - "Using `past_key_values` argument with `generate()` when using a static KV cache is not supported. Please open an issue in Transformers GitHub repository. 
- ) - cache_cls = NEED_SETUP_CACHE_CLASSES_MAPPING["static"] - if not callable(getattr(self, "_setup_cache", None)): - raise ValueError( - "The `generation_config` defines a `cache_implementation` that is not compatible with this model." - " Make sure it has a `_setup_cache` function." - ) - self._setup_cache(cache_cls, max_batch_size=batch_size, max_cache_len=generation_config.max_length) + model_kwargs["past_key_values"] = self._get_static_cache(batch_size, generation_config.max_length) self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) @@ -1844,14 +1872,6 @@ def typeerror(): **model_kwargs, ) - if generation_config.cache_implementation in NEED_SETUP_CACHE_CLASSES_MAPPING: - if not callable(getattr(self, "_reset_cache", None)): - raise ValueError( - "A `static_cache` was used to generate but there was a failure when trying to release the cache. " - " Make sure this model implements a `_reset_cache` function." - ) - self._reset_cache() - return result def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, device: torch.device) -> bool: diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py index 41bb4c051692..f20edf17cf39 100644 --- a/src/transformers/models/cohere/modeling_cohere.py +++ b/src/transformers/models/cohere/modeling_cohere.py @@ -340,6 +340,11 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -732,27 +737,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - COHERE_INPUTS_DOCSTRING = r""" Args: @@ -896,14 +880,11 @@ def forward( inputs_embeds = self.embed_tokens(input_ids) past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise 
ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -911,7 +892,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -980,7 +961,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -992,9 +973,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1003,9 +987,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1027,6 +1011,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. 
if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1184,13 +1172,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1208,8 +1189,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1249,9 +1229,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/dbrx/modeling_dbrx.py b/src/transformers/models/dbrx/modeling_dbrx.py index 99b865c773f8..7c2e6abbca97 100644 --- a/src/transformers/models/dbrx/modeling_dbrx.py +++ b/src/transformers/models/dbrx/modeling_dbrx.py @@ -15,7 +15,7 @@ """ PyTorch DBRX model. """ import math -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -354,6 +354,11 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs: Any, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) logger.info("Implicitly setting `output_attentions` to False as it is not supported in Flash Attention.") output_attentions = False @@ -622,6 +627,7 @@ def forward( value_states, attn_mask=causal_mask, dropout_p=self.attn_pdrop if self.training else 0.0, + is_causal=causal_mask is None and q_len > 1, ) attn_output = attn_output.transpose(1, 2).contiguous() @@ -957,28 +963,6 @@ def _init_weights(self, module: nn.Module): module.v1.data.normal_(mean=0.0, std=std) module.w2.data.normal_(mean=0.0, std=std) - def _setup_cache(self, cache_cls: Any, max_batch_size: int, max_cache_len: int): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with " - + "`attn_implementation==flash_attention_2`. Make sure to use " - + "`spda` in the mean time and open an issue at https://github.com/huggingface/transformers." 
- ) - - for block in self.transformer.blocks: - device = block.norm_attn_norm.norm_1.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = block.norm_attn_norm.attn.out_proj.weight.dtype - block.norm_attn_norm.attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for block in self.transformer.blocks: - block.norm_attn_norm.attn.past_key_value = None - DBRX_INPUTS_DOCSTRING = r""" Args: @@ -1131,22 +1115,18 @@ def forward( inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1205,7 +1185,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple( @@ -1221,28 +1203,45 @@ def forward( router_logits=all_router_logits, ) - # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static - # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. - # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using - # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114 + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask def _update_causal_mask( - self, attention_mask: Optional[torch.Tensor], input_tensor: torch.Tensor, cache_position: torch.Tensor - ) -> Optional[torch.Tensor]: + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + ): + # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static + # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. + # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using + # `fullgraph=True`. 
See more context in https://github.com/huggingface/transformers/pull/29114 if self.config._attn_implementation == "flash_attention_2": if attention_mask is not None and 0.0 in attention_mask: return attention_mask return None + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens + ): + return None + dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(self.blocks[0].norm_attn_norm.attn, "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( - attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1 + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 ) - target_length = int(target_length) causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) if sequence_length != 1: @@ -1259,6 +1258,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1273,17 +1276,10 @@ def _update_causal_mask( and attention_mask is not None and attention_mask.device.type == "cuda" ): - # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400). - is_tracing = ( - torch.jit.is_tracing() - or isinstance(input_tensor, torch.fx.Proxy) - or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling()) - ) - if not is_tracing and torch.any(attention_mask != 1): - # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when - # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. - # Details: https://github.com/pytorch/pytorch/issues/110213 - causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) return causal_mask @@ -1431,28 +1427,35 @@ def forward( router_logits=outputs.router_logits, ) + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation def prepare_inputs_for_generation( self, - input_ids: torch.Tensor, - past_key_values: Optional[Cache] = None, - attention_mask: Optional[torch.Tensor] = None, - inputs_embeds: Optional[torch.Tensor] = None, - **kwargs: Any, - ) -> Dict[str, Any]: + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + cache_position=None, + use_cache=True, + **kwargs, + ): past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): - cache_length = past_key_values.get_seq_length() - past_length = past_key_values.seen_tokens - max_cache_length = past_key_values.get_max_length() + past_length = cache_position[0] if cache_position is not None else past_key_values.get_seq_length() + max_cache_length = ( + torch.tensor(past_key_values.get_max_length(), device=input_ids.device) + if past_key_values.get_max_length() is not None + else None + ) + cache_length = past_length if max_cache_length is None else torch.min(max_cache_length, past_length) + # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects else: cache_length = past_length = past_key_values[0][0].shape[2] max_cache_length = None # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard @@ -1477,22 +1480,6 @@ def prepare_inputs_for_generation( if past_key_values: position_ids = position_ids[:, -input_ids.shape[1] :] - if self.generation_config.cache_implementation == "static": - # generation with static cache - cache_position = kwargs.get("cache_position", None) - if cache_position is None: - past_length = 0 - else: - past_length = cache_position[-1] + 1 - input_ids = input_ids[:, past_length:] - position_ids = position_ids[:, past_length:] if position_ids is not None else None - - # TODO @gante we should only keep a `cache_position` in generate, and do +=1. - # same goes for position ids. Could also help with continued generation. - input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] - cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) - position_ids = position_ids.contiguous() if position_ids is not None else None - # if `inputs_embeds` are passed, we only want to use them in the 1st generation step if inputs_embeds is not None and past_key_values is None: model_inputs = {"inputs_embeds": inputs_embeds} @@ -1502,12 +1489,18 @@ def prepare_inputs_for_generation( # TODO: use `next_tokens` directly instead. 
model_inputs = {"input_ids": input_ids.contiguous()} + input_length = position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1] + if cache_position is None: + cache_position = torch.arange(past_length, past_length + input_length, device=input_ids.device) + elif use_cache: + cache_position = cache_position[-input_length:] + model_inputs.update( { "position_ids": position_ids, "cache_position": cache_position, "past_key_values": past_key_values, - "use_cache": kwargs.get("use_cache"), + "use_cache": use_cache, "attention_mask": attention_mask, } ) diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py index e5b6b207748a..6a5cb365212a 100644 --- a/src/transformers/models/gemma/modeling_gemma.py +++ b/src/transformers/models/gemma/modeling_gemma.py @@ -332,6 +332,11 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -613,7 +618,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -715,23 +720,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - weights = layer.self_attn.o_proj.weight - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=weights.device, dtype=weights.dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - GEMMA_INPUTS_DOCSTRING = r""" Args: @@ -848,7 +836,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -877,13 +865,11 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + 
past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) @@ -891,7 +877,7 @@ def forward( if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -950,7 +936,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -966,7 +954,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -978,9 +966,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -989,9 +980,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1013,6 +1004,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. 
if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1074,7 +1069,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1166,13 +1161,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1190,8 +1178,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard @@ -1231,9 +1218,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, @@ -1293,7 +1277,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 80d5dad3cbd8..1dbcbc76f3c2 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -29,7 +29,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss from ...activations import ACT2FN -from ...cache_utils import DynamicCache # we need __iter__ and __len__ of pkv +from ...cache_utils import Cache, DynamicCache # we need __iter__ and __len__ of pkv from ...modeling_attn_mask_utils import ( AttentionMaskConverter, ) @@ -1807,7 +1807,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index 905edf5f71a6..42000824ce34 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -428,6 +428,12 @@ def forward( cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if isinstance(past_key_value, StaticCache): + raise ValueError( + "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " + "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" + ) + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -708,7 +714,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -809,27 +815,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, 
"_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - LLAMA_INPUTS_DOCSTRING = r""" Args: @@ -944,7 +929,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -973,23 +958,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1042,7 +1022,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1058,7 +1040,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1070,9 +1052,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1081,9 +1066,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1105,6 +1090,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1164,7 +1153,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1262,13 +1251,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1286,8 +1268,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard @@ -1327,9 +1308,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, @@ -1388,7 +1366,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -1505,7 +1483,7 @@ def forward( input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, start_positions: Optional[torch.LongTensor] = None, end_positions: Optional[torch.LongTensor] = None, diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index c013967c78f1..665e95a8fd78 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -1301,7 +1301,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index c78e907d5fdb..e5a81c4c9083 100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -1525,7 +1525,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py index e3b0e05127c5..7c3370ce7289 100644 --- a/src/transformers/models/olmo/modeling_olmo.py +++ b/src/transformers/models/olmo/modeling_olmo.py @@ -688,7 +688,7 @@ def forward( hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, + past_key_value: Optional[Cache] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -790,27 +790,6 @@ def _init_weights(self, module): if module.padding_idx is not None: module.weight.data[module.padding_idx].zero_() - def _setup_cache(self, cache_cls, max_batch_size, max_cache_len: Optional[int] = None): - if self.config._attn_implementation == "flash_attention_2" and cache_cls == StaticCache: - raise ValueError( - "`static` cache 
implementation is not compatible with `attn_implementation==flash_attention_2` " - "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers" - ) - - for layer in self.model.layers: - device = layer.input_layernorm.weight.device - if hasattr(self.config, "_pre_quantization_dtype"): - dtype = self.config._pre_quantization_dtype - else: - dtype = layer.self_attn.o_proj.weight.dtype - layer.self_attn.past_key_value = cache_cls( - self.config, max_batch_size, max_cache_len, device=device, dtype=dtype - ) - - def _reset_cache(self): - for layer in self.model.layers: - layer.self_attn.past_key_value = None - OLMO_INPUTS_DOCSTRING = r""" Args: @@ -926,7 +905,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, @@ -955,23 +934,18 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - past_seen_tokens = 0 - if use_cache: # kept for BC (cache positions) - if not isinstance(past_key_values, StaticCache): - past_key_values = DynamicCache.from_legacy_cache(past_key_values) - past_seen_tokens = past_key_values.get_seq_length() + if use_cache and not isinstance(past_key_values, Cache): # kept for BC (non `Cache` `past_key_values` inputs) + past_key_values = DynamicCache.from_legacy_cache(past_key_values) if cache_position is None: - if isinstance(past_key_values, StaticCache): - raise ValueError("cache_position is a required argument when using StaticCache.") + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 cache_position = torch.arange( past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ) - if position_ids is None: position_ids = cache_position.unsqueeze(0) - causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_seen_tokens) + causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values) # embed positions hidden_states = inputs_embeds @@ -1024,7 +998,9 @@ def forward( next_cache = None if use_cache: next_cache = ( - next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache + next_decoder_cache.to_legacy_cache() + if isinstance(next_decoder_cache, DynamicCache) + else next_decoder_cache ) if not return_dict: return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) @@ -1041,7 +1017,7 @@ def _update_causal_mask( attention_mask: torch.Tensor, input_tensor: torch.Tensor, cache_position: torch.Tensor, - past_seen_tokens: int, + past_key_values: Cache, ): # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes. @@ -1053,9 +1029,12 @@ def _update_causal_mask( return attention_mask return None - if self.config._attn_implementation == "sdpa": - # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, - # in order to dispatch on Flash Attention 2. 
+ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + if self.config._attn_implementation == "sdpa" and not using_static_cache: if AttentionMaskConverter._ignore_causal_mask_sdpa( attention_mask, inputs_embeds=input_tensor, past_key_values_length=past_seen_tokens ): @@ -1064,9 +1043,9 @@ def _update_causal_mask( dtype, device = input_tensor.dtype, input_tensor.device min_dtype = torch.finfo(dtype).min sequence_length = input_tensor.shape[1] - if hasattr(getattr(self.layers[0], "self_attn", {}), "past_key_value"): # static cache - target_length = self.config.max_position_embeddings - else: # dynamic cache + if using_static_cache: + target_length = past_key_values.get_max_length() + else: target_length = ( attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) @@ -1088,6 +1067,10 @@ def _update_causal_mask( # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with # cache. In that case, the 4D attention mask attends to the newest tokens only. if attention_mask.shape[-2] < cache_position[0] + sequence_length: + logger.warning_once( + "Passing a 4d mask shorter than the input length is deprecated and will be removed in " + "transformers v4.42.0" + ) offset = cache_position[0] else: offset = 0 @@ -1243,13 +1226,6 @@ def prepare_inputs_for_generation( use_cache=True, **kwargs, ): - # With static cache, the `past_key_values` is None - # TODO joao: standardize interface for the different Cache classes and remove of this if - has_static_cache = False - if past_key_values is None: - past_key_values = getattr(getattr(self.model.layers[0], "self_attn", {}), "past_key_value", None) - has_static_cache = past_key_values is not None - past_length = 0 if past_key_values is not None: if isinstance(past_key_values, Cache): @@ -1267,8 +1243,7 @@ def prepare_inputs_for_generation( # Keep only the unprocessed tokens: # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where - # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as - # input) + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input) if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. 
We can discard @@ -1308,9 +1283,6 @@ def prepare_inputs_for_generation( elif use_cache: cache_position = cache_position[-input_length:] - if has_static_cache: - past_key_values = None - model_inputs.update( { "position_ids": position_ids, diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index c83ba413952b..8d4ad532074f 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -927,7 +927,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py index 13719166edf9..b23073d332e4 100644 --- a/src/transformers/models/phi/modeling_phi.py +++ b/src/transformers/models/phi/modeling_phi.py @@ -1313,7 +1313,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py index f9364d130b7e..530c22a87449 100644 --- a/src/transformers/models/phi3/modeling_phi3.py +++ b/src/transformers/models/phi3/modeling_phi3.py @@ -1419,7 +1419,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 70072c91720a..ca349dca1c1b 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -1509,7 +1509,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/stablelm/modeling_stablelm.py b/src/transformers/models/stablelm/modeling_stablelm.py index 3262f2cd3c61..bc133ffb3d73 100755 --- a/src/transformers/models/stablelm/modeling_stablelm.py +++ b/src/transformers/models/stablelm/modeling_stablelm.py @@ -1299,7 +1299,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + 
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py index ca4c8af23304..61e8518d659c 100644 --- a/src/transformers/models/starcoder2/modeling_starcoder2.py +++ b/src/transformers/models/starcoder2/modeling_starcoder2.py @@ -1292,7 +1292,7 @@ def forward( input_ids: torch.LongTensor = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, - past_key_values: Optional[List[torch.FloatTensor]] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index dc24fd848c81..0592922e4470 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -18,11 +18,11 @@ import unittest import pytest +from packaging import version from parameterized import parameterized -from transformers import LlamaConfig, StaticCache, is_torch_available, logging, set_seed +from transformers import LlamaConfig, is_torch_available, set_seed from transformers.testing_utils import ( - CaptureLogger, require_bitsandbytes, require_flash_attn, require_read_token, @@ -684,15 +684,28 @@ def test_model_13b_greedy_generation(self): @require_torch_gpu @require_read_token def test_compile_static_cache(self): + # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2 + # work as intended. See https://github.com/pytorch/pytorch/issues/121943 + if version.parse(torch.__version__) < version.parse("2.3.0"): + self.skipTest("This test requires torch >= 2.3 to run.") + NUM_TOKENS_TO_GENERATE = 40 + # Note on `EXPECTED_TEXT_COMPLETION`'s diff: the current value matches the original test if the original test + # was changed to have a cache of 53 tokens (as opposed to 4096), on Ampere GPUs. EXPECTED_TEXT_COMPLETION = { - 7: [ - "Simply put, the theory of relativity states that 1) the speed of light is constant, 2) the speed of light is the same for all observers, and 3) the laws of physics are the same for all observers.", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", - ], 8: [ - "Simply put, the theory of relativity states that 1) the speed of light is the same for all observers, and 2) the laws of physics are the same for all observers.\nThe first part of the theory of relativity", - "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + "Simply put, the theory of relativity states that 1) the speed of light is constant in all inertial " + "reference frames, and 2) the laws of physics are the same for all inertial reference frames.\nThe " + "theory of relativ", + "My favorite all time favorite condiment is ketchup. I love it on everything. I love it on my eggs, " + "my fries, my chicken, my burgers, my hot dogs, my sandwiches, my salads, my p", + ], + 7: [ + "Simply put, the theory of relativity states that 1. 
surely nothing is faster than light.\nThe theory " + "goes that nothing travels faster than light, but the faster you go, the slower everything else will " + "be.\nThe theory of relativity", + "My favorite all time favorite condiment is ketchup. I love it on hamburgers, hot dogs, fries, eggs, " + "and even on a good old fashioned cheeseburger. I love it on everything. I love it so", ], } @@ -706,38 +719,25 @@ def test_compile_static_cache(self): ) inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device) - def decode_one_tokens(model, cur_token, input_pos, cache_position): - logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True - )[0] - new_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - return new_token - - batch_size, seq_length = inputs["input_ids"].shape - with torch.no_grad(): - model._setup_cache(StaticCache, 2, max_cache_len=4096) - cache_position = torch.arange(seq_length, device=torch_device) - generated_ids = torch.zeros( - batch_size, seq_length + NUM_TOKENS_TO_GENERATE + 1, dtype=torch.int, device=torch_device - ) - generated_ids[:, cache_position] = inputs["input_ids"].to(torch_device).to(torch.int) - - logits = model(**inputs, cache_position=cache_position, return_dict=False, use_cache=True)[0] - next_token = torch.argmax(logits[:, -1], dim=-1)[:, None] - generated_ids[:, seq_length] = next_token[:, 0] - - decode_one_tokens = torch.compile(decode_one_tokens, mode="reduce-overhead", fullgraph=True) - cache_position = torch.tensor([seq_length + 1], device=torch_device) - for _ in range(1, NUM_TOKENS_TO_GENERATE): - with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - with CaptureLogger(logging.get_logger(__name__)) as cl: - next_token = decode_one_tokens(model, next_token.clone(), None, cache_position) - self.assertNotIn("skipping cudagraphs due to", cl.out) - generated_ids[:, cache_position] = next_token.int() - cache_position += 1 - - text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) - self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], text) + # Dynamic Cache + generated_ids = model.generate(**inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False) + dynamic_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[8], dynamic_text) # Both GPU architectures have the same output + + # Static Cache + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_text) + + # Static Cache + compile + model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True) + generated_ids = model.generate( + **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static" + ) + static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text) @require_torch diff --git a/tests/quantization/aqlm_integration/test_aqlm.py b/tests/quantization/aqlm_integration/test_aqlm.py index 46b64573b938..3b0dd99adcd9 100644 --- a/tests/quantization/aqlm_integration/test_aqlm.py +++ 
b/tests/quantization/aqlm_integration/test_aqlm.py @@ -196,9 +196,14 @@ def test_quantized_model_compile(self): """ # Sample tokens greedily - def decode_one_tokens(model, cur_token, input_pos, cache_position): + def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_values): logits = model( - cur_token, position_ids=input_pos, cache_position=cache_position, return_dict=False, use_cache=True + cur_token, + position_ids=input_pos, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, )[0] new_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) @@ -209,7 +214,13 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): seq_length = input_ids.shape[1] # Setup static KV cache for generation - self.quantized_model._setup_cache(StaticCache, 1, max_cache_len=seq_length + self.max_new_tokens + 1) + past_key_values = StaticCache( + config=self.quantized_model.config, + max_batch_size=1, + max_cache_len=seq_length + self.max_new_tokens + 1, + device=torch_device, + dtype=self.quantized_model.config._pre_quantization_dtype, + ) # Allocate token ids to be generated and copy prefix ids cache_position = torch.arange(seq_length, device=torch_device) @@ -217,7 +228,13 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): generated_ids[:, cache_position] = input_ids.to(torch_device).to(torch.int) # Do a forward pass to fill the prefix cache and compile the kernels if necessary - logits = self.quantized_model(input_ids, cache_position=cache_position, return_dict=False, use_cache=True)[0] + logits = self.quantized_model( + input_ids, + cache_position=cache_position, + past_key_values=past_key_values, + return_dict=False, + use_cache=True, + )[0] next_token = torch.argmax(logits[:, [-1]], dim=-1).to(torch.int) generated_ids[:, [seq_length]] = next_token @@ -229,7 +246,9 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position): cache_position = torch.tensor([seq_length + 1], device=torch_device) for _ in range(1, self.max_new_tokens): with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): - next_token = decode_one_tokens(self.quantized_model, next_token.clone(), None, cache_position) + next_token = decode_one_tokens( + self.quantized_model, next_token.clone(), None, cache_position, past_key_values + ) generated_ids.index_copy_(1, cache_position, next_token) cache_position += 1
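
The test changes above capture the new end-to-end workflow: rather than priming a per-layer cache through the removed `_setup_cache`/`_reset_cache` helpers, callers either request `cache_implementation="static"` from `generate` or construct a `StaticCache` themselves and hand it to the forward pass through `past_key_values`. A minimal sketch of that workflow follows; the checkpoint name, prompt, and token budget are illustrative assumptions, not values taken from this diff.

```py
# Sketch of the static-cache workflow exercised by the updated tests.
# Checkpoint, prompt, and cache/generation sizes below are illustrative only.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model_id = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint, any decoder-only model works
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
inputs = tokenizer("Simply put, the theory of relativity states that", return_tensors="pt").to(model.device)

# 1) Let `generate` build and manage the static cache itself.
generated = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")

# 2) Or build the cache explicitly and pass it through `past_key_values`,
#    as the AQLM test now does; `cache_position` can be omitted, in which
#    case the model derives it from the cache's current sequence length.
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=1024, device=model.device, dtype=model.dtype
)
logits = model(**inputs, past_key_values=past_key_values, use_cache=True).logits

# 3) Because the static cache has fixed shapes, the compiled forward pass is
#    not re-specialized at every decoding step.
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated = model.generate(**inputs, max_new_tokens=40, do_sample=False, cache_implementation="static")
print(tokenizer.batch_decode(generated, skip_special_tokens=True))
```

Routing the cache through `past_key_values`, mirrored by the `Optional[Union[Cache, List[torch.FloatTensor]]]` annotations added across the model signatures, keeps cache ownership with the caller instead of hiding it on each attention layer, which is what lets the same compiled forward pass be reused across calls.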