Cache: Static cache as a standalone object #30476
Conversation
Overall LGTM
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
-       """Returns the sequence length of the cached states that were seen by the model. `layer_idx` kept for BC"""
+       """Returns the sequence length of the cached states that were seen by the model."""
        # Occupied cache == any slot in the 3rd dim (sequence length) holds a non-zero value. To save on compute, let's
        # limit the check to the first batch member and head dimension.
        # TODO: This is error prone, a filled cache may be `0.0`. Let's use a stateless integer instead, after
        # https://github.com/pytorch/pytorch/issues/120248 is fixed
-       return (self.key_cache[0, 0].any(dim=-1)).sum()
+       return (self.key_cache[layer_idx][0, 0].any(dim=-1)).sum()
will remove this one
it's slow and not reliable; `generate` should never use it
(needs a deprecation cycle, and it's easier to do after we isolate the prefill stage, so I'm going to leave it off this PR)
fine by me to deprecate
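For illustration, a minimal sketch of the counter-based alternative the TODO alludes to; the class and method names below are hypothetical, not part of this PR:

```python
# Hypothetical sketch only: track the number of cached tokens with a plain
# integer counter instead of scanning the cache tensor for non-zero slots.
# This avoids the failure mode where a legitimately-zero cached value is
# miscounted, and is O(1) instead of a reduction over the cache.
class TokenCounter:
    def __init__(self) -> None:
        self.seen_tokens = 0

    def update(self, num_new_tokens: int) -> None:
        # Called once per forward pass with the number of tokens written to the cache.
        self.seen_tokens += num_new_tokens

    def get_seq_length(self) -> int:
        return self.seen_tokens
```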
    raise ValueError(
        "`static` cache implementation is not compatible with `attn_implementation==flash_attention_2` "
        "make sure to use `sdpa` in the mean time, and open an issue at https://github.com/huggingface/transformers"
    )
Would be compatible if we slice the q k v efficiently, but that's too much trouble
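For context, this error is only hit when the model was loaded with flash_attention_2; a hedged example of loading with SDPA instead (the checkpoint name is a placeholder):

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative only: load the model with the SDPA attention implementation so
# that the `static` cache implementation can be used. The checkpoint is a placeholder.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    attn_implementation="sdpa",
    torch_dtype=torch.float16,
)
```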
Taking this on to finish!

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

If you use the memory-efficient kernel it's 20% slower. That's what we use by default.

https://gist.github.com/ArthurZucker/ae0a86ef8f841c0ef69aaa52ccbc0b03 for the benchmarks
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    using_static_cache = isinstance(past_key_values, StaticCache)
As I understand it, once the StaticCache is initialized, there is no need to pass it in the `past_key_values` argument. That's why an additional condition is necessary. Suggestion:

    using_static_cache = isinstance(past_key_values, StaticCache) or isinstance(
        getattr(self.layers[0].self_attn, "past_key_value", None), StaticCache
    )
@poedator This PR changes precisely the assumption you wrote: we will always need to pass the cache; after this PR it is an object that does NOT live inside the model.
This change will make the transformers team's work easier 🤗
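To make the new usage concrete, a hedged sketch of building the standalone `StaticCache` outside the model and passing it explicitly; the checkpoint name is a placeholder and the exact `StaticCache` constructor arguments may differ between versions:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

# Sketch under assumptions: the checkpoint is a placeholder, and the StaticCache
# constructor arguments shown here may not match every release.
model_id = "meta-llama/Llama-2-7b-hf"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")

inputs = tokenizer("The quick brown fox", return_tensors="pt").to(model.device)
past_key_values = StaticCache(
    config=model.config, max_batch_size=1, max_cache_len=256, device=model.device, dtype=torch.float16
)
# `cache_position` is required when a StaticCache is passed (see the check further down in this PR)
cache_position = torch.arange(inputs.input_ids.shape[1], device=model.device)
outputs = model(**inputs, past_key_values=past_key_values, cache_position=cache_position, use_cache=True)
```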
same comment as here: #30437 (comment); please make sure to validate these tests on the T4 and A10 runners 🙏
There was indeed a mismatch on T4 🤗
Absolutely great work
src/transformers/cache_utils.py (outdated diff)
    self.key_cache[layer_idx] *= 0.0
    self.value_cache[layer_idx] *= 0.0
Suggested change:

-   self.key_cache[layer_idx] *= 0.0
-   self.value_cache[layer_idx] *= 0.0
+   self.key_cache[layer_idx] = 0.0
+   self.value_cache[layer_idx] = 0.0
might be faster?
setting to a new tensor produces a graph break 💔 (I'm assuming you meant `self.key_cache[layer_idx] = torch.zeros(...)`)
No no, I think just filling them with zeros should work
That would result in `TypeError: 'float' object is not subscriptable` when indexing the cache :D
But filling them with zeros via `tensor.zero_()` works 👍
ok 👍🏻 let's go with that then!
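A minimal sketch of the in-place reset that was settled on (simplified; the actual method in the PR may differ):

```python
# Simplified sketch: zero the pre-allocated cache tensors in place with `zero_()`.
# Reassigning the attribute (e.g. `self.key_cache[layer_idx] = torch.zeros(...)`)
# would create new tensors and introduce a graph break under torch.compile.
def reset(self):
    for layer_idx in range(len(self.key_cache)):
        self.key_cache[layer_idx].zero_()
        self.value_cache[layer_idx].zero_()
```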
    if cache_position is None:
        if isinstance(past_key_values, StaticCache):
            raise ValueError("cache_position is a required argument when using StaticCache.")
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
Arf, alright, let's maybe add a TODO, as we won't be initializing with `get_seq_length` later on!
Added a TODO on `get_seq_length` 👍
    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
    # to infer the attention mask.
    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
    using_static_cache = isinstance(past_key_values, StaticCache)
    if self.config._attn_implementation == "sdpa" and not using_static_cache:
this is new, and since we pass `cache_position`, let's use `cache_position[0]`
Agreed in theory, can't do in practice: breaks torch.fx tests 💔
yeah thought so
    if using_static_cache:
        target_length = past_key_values.get_max_length()
can't we always use get_max_length()?
get_max_length() is None in the dynamic caches
It should be seq_length
but alright
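A hedged sketch of the branch under discussion; `sequence_length` stands for the length of the incoming input, and the dynamic-cache fallback here is an assumption rather than the PR's exact code:

```python
# Simplified sketch: with a static cache the causal mask must cover the whole
# pre-allocated window, so use get_max_length(); for dynamic caches that method
# returns None, so fall back to a length derived from the tokens seen so far.
if using_static_cache:
    target_length = past_key_values.get_max_length()
else:
    target_length = past_seen_tokens + sequence_length
```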
    @slow
    @require_torch_gpu
    @require_read_token
    def test_compile_static_cache(self):
should require torch > 2.2
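One way to express that guard, as a hedged sketch rather than the decorator the test suite actually uses:

```python
import unittest

import torch
from packaging import version

# Hypothetical guard, assuming the test should only run on torch > 2.2
# (torch==2.2 hits the compilation bug mentioned further down).
requires_recent_torch = unittest.skipUnless(
    version.parse(torch.__version__) > version.parse("2.2"),
    "test requires torch > 2.2",
)
```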
    # Static Cache + compile
    model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
    generated_ids = model.generate(
        **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
    )
    static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    self.assertEqual(EXPECTED_TEXT_COMPLETION[self.cuda_compute_capability_major_version], static_compiled_text)
good thanks
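For reference, a hedged sketch of the setup the test fragment above assumes; the checkpoint, prompt, and expected-text constants are placeholders, not the values from the actual test:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative setup only; checkpoint, prompt, and expected outputs are placeholders.
NUM_TOKENS_TO_GENERATE = 40
EXPECTED_TEXT_COMPLETION = {
    7: ["<expected completion on compute capability 7.x, e.g. T4>"],
    8: ["<expected completion on compute capability 8.x, e.g. A10/A100>"],
}

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf", device_map="cuda", torch_dtype=torch.float16
)
inputs = tokenizer(["Simply put, the theory of relativity states that"], return_tensors="pt").to(model.device)

# Static Cache + compile, as in the diff above
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)
generated_ids = model.generate(
    **inputs, max_new_tokens=NUM_TOKENS_TO_GENERATE, do_sample=False, cache_implementation="static"
)
static_compiled_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
```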
* 4d mask fixes
* Update custom 4D mask logic
* test moved to mixin
* extra tests 4d mask
* upd 4d mask and StaticCache handling
* added Mask4DTestHard to mistral tests
* post-rebase fixes
* test fixes for StaticCache
* make fix-copies
* upd 1 after #30476
* fix common tests
* rm elif attention_mask.dim() == 4:
* tests combined, fixed, mixtral supported
* bigbird style chg reverted
* rm if attention_mask.dim() == 2
* modeling_llama formatting chg

Co-authored-by: Joao Gante <[email protected]>
    # `torch==2.2` will throw an error on this test (as in other compilation tests), but torch==2.1.2 and torch>2.2
    # work as intended. See https://github.com/pytorch/pytorch/issues/121943
and 2.2.1 works as well

What does this PR do?
Replaces the current format of `StaticCache` [an object living inside a model, containing the cache for one layer] with a standalone object matching the other `Cache` objects. The new format preserves the existing `torch.compile` capabilities while being easier to manipulate, especially outside a model. In the process, it removes all traces of the previous format across all models, tests, and docs.
Fixes #30417 (In place of #30437)
Fixes #30351
Benchmarks
(RTX3090, tiny-llama model, `torch==2.4.0.dev20240424+cu121`)

Benchmark code

commit == 14b19c4ef365f90797e07b2a20caaaaf3901b2d2: [benchmark plot]

v4.39.0: [benchmark plot]