[Core generation] Adds support for static KV cache
#27931
Changes from 86 commits
`@@ -1,8 +1,12 @@`

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import torch

from .configuration_utils import PretrainedConfig


@dataclass
class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
```
`@@ -320,3 +324,74 @@ def reorder_cache(self, beam_idx: torch.LongTensor):`

```python
        self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
        device = self.value_cache[layer_idx].device
        self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))


class StaticCache(Cache):
    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        self.head_dim = config.hidden_size // config.num_attention_heads
        self.num_heads = config.num_attention_heads
        self.dtype = config.torch_dtype if config.torch_dtype is not None else torch.float32

        # Pre-allocate the full cache once, so its shape never changes during decoding.
        cache_shape = (max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
        self.key_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)
        self.value_cache: torch.Tensor = torch.zeros(cache_shape, dtype=self.dtype, device=device)

        self.seen_tokens = 0

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `position_ids`
                to know where to write, and the attention mask has to be updated to make sure the unseen
                tokens are not attended to.

        Return:
            A tuple containing the updated key and value states.
        """
        position_ids = cache_kwargs.get("position_ids")

        k_out = self.key_cache
        v_out = self.value_cache

        # In-place scatter into the pre-allocated buffers; the returned shape stays constant.
        k_out[:, :, position_ids] = key_states
        v_out[:, :, position_ids] = value_states

        self.seen_tokens += key_states.shape[-2]
        return k_out, v_out

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model. A layer index can be optionally passed."""
        return self.seen_tokens

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states, i.e. `max_cache_len`."""
        return self.max_cache_len

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        device = self.key_cache.device
        self.key_cache = self.key_cache.index_select(0, beam_idx.to(device))
        device = self.value_cache.device
        self.value_cache = self.value_cache.index_select(0, beam_idx.to(device))

    def to_legacy_cache(self):
        """Dummy function for BC; should not be used."""
        return None
```
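To make the usage concrete, here is a minimal sketch of how the class above can be driven during prefill. `DummyConfig` is a hypothetical stand-in carrying only the four attributes `__init__` reads; in real use you would pass a model's actual `PretrainedConfig`, with `StaticCache` imported from this module.

```python
import torch

class DummyConfig:
    # Hypothetical stand-in for `PretrainedConfig`, for illustration only.
    max_position_embeddings = 128
    hidden_size = 64            # -> head_dim = 64 // 8 = 8
    num_attention_heads = 8
    torch_dtype = torch.float32

cache = StaticCache(DummyConfig(), max_batch_size=2, max_cache_len=None, device="cpu")

# Prefill: write 16 tokens at positions 0..15. Note the positions are a tensor.
position_ids = torch.arange(16)
key_states = torch.randn(2, 8, 16, 8)    # (batch, heads, seq_len, head_dim)
value_states = torch.randn(2, 8, 16, 8)
k, v = cache.update(key_states, value_states, layer_idx=0, cache_kwargs={"position_ids": position_ids})

assert k.shape == (2, 8, 128, 8)   # always the full pre-allocated shape
assert cache.get_seq_length() == 16
```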
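Continuing the same sketch, the decode phase then appends one token per step. Passing the position as a 1-element tensor rather than a Python int follows the docstring's warning about introducing a copy to the device; `reorder_cache` likewise shuffles the whole pre-allocated buffers along the batch dimension.

```python
# Decode: one new token per step at positions 16..19.
for step in range(16, 20):
    position_ids = torch.tensor([step])
    new_k = torch.randn(2, 8, 1, 8)
    new_v = torch.randn(2, 8, 1, 8)
    cache.update(new_k, new_v, layer_idx=0, cache_kwargs={"position_ids": position_ids})

assert cache.get_seq_length() == 20

# Beam search: select beams 1 and 0, in that order.
cache.reorder_cache(torch.tensor([1, 0]))
assert cache.key_cache.shape == (2, 8, 128, 8)   # shape is unchanged
```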