
Commit cf32ee1

gante and ArthurZucker authored
Cache: use batch_size instead of max_batch_size (#32657)
* more precise name
* better docstrings
* Update src/transformers/cache_utils.py

Co-authored-by: Arthur <[email protected]>

---------

Co-authored-by: Arthur <[email protected]>
1 parent 8f9fa3b commit cf32ee1

File tree: 9 files changed (+112, -54 lines)

docs/source/en/llm_optims.md

Lines changed: 2 additions & 2 deletions
@@ -99,7 +99,7 @@ model.generation_config.max_new_tokens = 16
 past_key_values = StaticCache(
     config=model.config,
-    max_batch_size=1,
+    batch_size=1,
     # If you plan to reuse the cache, make sure the cache length is large enough for all cases
     max_cache_len=prompt_length+(model.generation_config.max_new_tokens*2),
     device=model.device,
@@ -161,7 +161,7 @@ There are a few important things you must do to enable static kv-cache and `torc
 batch_size, seq_length = inputs["input_ids"].shape
 with torch.no_grad():
     past_key_values = StaticCache(
-        config=model.config, max_batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
+        config=model.config, batch_size=2, max_cache_len=4096, device=torch_device, dtype=model.dtype
     )
     cache_position = torch.arange(seq_length, device=torch_device)
     generated_ids = torch.zeros(
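For context, the first hunk above belongs to the doc's "reuse a static cache across `generate()` calls" example; a condensed, hedged sketch of that flow with the renamed keyword (the checkpoint and prompt are illustrative):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, StaticCache

model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.float16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
model.generation_config.max_new_tokens = 16

inputs = tokenizer("Static KV caches are useful because", return_tensors="pt").to(model.device)
prompt_length = inputs.input_ids.shape[1]

past_key_values = StaticCache(
    config=model.config,
    batch_size=1,  # was `max_batch_size` before this commit
    # If you plan to reuse the cache, make sure the cache length is large enough for all cases
    max_cache_len=prompt_length + (model.generation_config.max_new_tokens * 2),
    device=model.device,
    dtype=model.dtype,
)
outputs = model.generate(**inputs, past_key_values=past_key_values)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))
```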

src/transformers/cache_utils.py

Lines changed: 90 additions & 37 deletions
@@ -977,13 +977,14 @@ class StaticCache(Cache):
     Parameters:
         config (`PretrainedConfig`):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search
         max_cache_len (`int`):
             The maximum sequence length with which the model will be used.
-        device (`torch.device`):
+        device (`torch.device` or `str`):
             The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.

     Example:
@@ -999,22 +1000,37 @@ class StaticCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
         ```
     """

-    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: torch.device = None,
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
         super().__init__()
-        self.max_batch_size = max_batch_size
+        if max_batch_size is not None:
+            logger.warning_once(
+                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.46. Use the more precisely named 'batch_size' argument instead."
+            )
+
+        self.batch_size = batch_size or max_batch_size
         self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         self.head_dim = (
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype if dtype is not None else torch.float32
+        self.dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads
             if getattr(config, "num_key_value_heads", None) is None
@@ -1024,7 +1040,7 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
         # Note: There will be significant perf decrease if switching to use 5D tensors instead.
-        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
         for idx in range(config.num_hidden_layers):
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
@@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache):
     Parameters:
         config (`PretrainedConfig`):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
         max_cache_len (`int`):
             The maximum sequence length with which the model will be used.
-        device (`torch.device`):
+        device (`torch.device` or `str`):
             The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.

     Example:
@@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
         ```
     """

-    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: torch.device = None,
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
         super().__init__()
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
@@ -1168,7 +1194,12 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
             )
         max_cache_len = min(config.sliding_window, max_cache_len)
         super().__init__(
-            config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
+            config=config,
+            batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+            max_batch_size=max_batch_size,
         )

     def update(
@@ -1407,13 +1438,14 @@ class HybridCache(Cache):
     Parameters:
         config (`PretrainedConfig):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
         max_cache_len (`int`):
             The maximum sequence length with which the model will be used.
-        device (`torch.device`, *optional*, defaults to `"cpu"`):
+        device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
             The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
+        dtype (torch.dtype, *optional*, defaults to `torch.float32`):
             The default `dtype` to use when initializing the layer.

     Example:
@@ -1429,28 +1461,42 @@ class HybridCache(Cache):
         >>> # Prepare a cache class and pass it to model's forward
         >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
         >>> max_generated_length = inputs.input_ids.shape[1] + 10
-        >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+        >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
        >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
         ```
     """

-    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: Union[torch.device, str] = "cpu",
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
         super().__init__()
+        if max_batch_size is not None:
+            logger.warning_once(
+                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.46. Use the more precisely named 'batch_size' argument instead."
+            )
         if not hasattr(config, "sliding_window") or config.sliding_window is None:
             raise ValueError(
                 "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                 "sliding window attention, please check if there is a `sliding_window` field in the model "
                 "config and it's not set to None."
             )
         self.max_cache_len = max_cache_len
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size or max_batch_size
         # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
         self.head_dim = (
             config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
         )

-        self.dtype = dtype if dtype is not None else torch.float32
+        self.dtype = dtype
         self.num_key_value_heads = (
             config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
         )
@@ -1459,9 +1505,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, devi
         )
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
         sliding_cache_shape = (
-            max_batch_size,
+            self.batch_size,
             self.num_key_value_heads,
             min(config.sliding_window, max_cache_len),
             self.head_dim,
@@ -1564,11 +1610,12 @@ class MambaCache:
     Arguments:
         config (`PretrainedConfig):
             The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
-        dtype (*optional*, defaults to `torch.float16`):
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
             The default `dtype` to use when initializing the layer.
-        device (`torch.device`, *optional*):
+        device (`torch.device` or `str`, *optional*):
             The device on which the cache should be initialized. Should be the same as the layer.

     Attributes:
@@ -1596,37 +1643,43 @@ class MambaCache:
         >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

         >>> # Prepare a cache class and pass it to model's forward
-        >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+        >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
         >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
         >>> past_kv = outputs.past_key_values
         ```
     """

+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
     def __init__(
         self,
         config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int = None,
         dtype: torch.dtype = torch.float16,
-        device: Optional[str] = None,
-        **kwargs,
+        device: Optional[Union[torch.device, str]] = None,
+        max_batch_size: Optional[int] = None,
     ):
+        if max_batch_size is not None:
+            logger.warning_once(
+                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.46. Use the more precisely named 'batch_size' argument instead."
+            )
         self.dtype = dtype
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size or max_batch_size
         self.intermediate_size = config.intermediate_size
         self.ssm_state_size = config.state_size
         self.conv_kernel_size = config.conv_kernel

         self.conv_states: torch.Tensor = torch.zeros(
             config.num_hidden_layers,
-            self.max_batch_size,
+            self.batch_size,
             self.intermediate_size,
             self.conv_kernel_size,
             device=device,
             dtype=dtype,
         )
         self.ssm_states: torch.Tensor = torch.zeros(
             config.num_hidden_layers,
-            self.max_batch_size,
+            self.batch_size,
             self.intermediate_size,
             self.ssm_state_size,
             device=device,
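The deprecation path is the same in `StaticCache`, `HybridCache`, and `MambaCache` (and `SlidingWindowCache` forwards to `StaticCache`): the old keyword is still accepted as the last parameter, a warning is logged once, and the value falls through via `batch_size or max_batch_size`. A small sketch of the intended behavior, using a tiny hand-built config so it runs without downloading weights (the config values are arbitrary):

```python
import torch
from transformers import LlamaConfig, StaticCache

# Arbitrary small config, just enough shape information for the cache.
config = LlamaConfig(num_hidden_layers=2, hidden_size=64, num_attention_heads=4, max_position_embeddings=128)

# New-style call: no warning.
cache = StaticCache(config=config, batch_size=4, max_cache_len=32, device="cpu", dtype=torch.float32)
assert cache.batch_size == 4

# Old-style call: still works, but logs
# "The 'max_batch_size' argument of StaticCache is deprecated and will be removed in v4.46. ..."
legacy = StaticCache(config=config, max_batch_size=4, max_cache_len=32, device="cpu", dtype=torch.float32)
assert legacy.batch_size == 4  # the deprecated value is forwarded to the new attribute
```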

src/transformers/generation/utils.py

Lines changed: 4 additions & 4 deletions
@@ -1426,7 +1426,7 @@ def _get_initial_cache_position(self, input_ids, model_kwargs):
         return model_kwargs

     def _get_cache(
-        self, cache_implementation: str, max_batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
+        self, cache_implementation: str, batch_size: int, max_cache_len: int, device: torch.device, model_kwargs
     ) -> Cache:
         """
         Sets a cache for `generate`, that will persist across calls. A new cache will only be initialized a
@@ -1448,7 +1448,7 @@ def _get_cache(
         need_new_cache = (
             not hasattr(self, "_cache")
             or (not isinstance(cache_to_check, cache_cls))
-            or cache_to_check.max_batch_size != max_batch_size
+            or cache_to_check.batch_size != batch_size
         )
         if cache_implementation != "mamba":
             need_new_cache = need_new_cache or cache_to_check.max_cache_len < max_cache_len
@@ -1473,7 +1473,7 @@ def _get_cache(

         cache_kwargs = {
             "config": self.config,
-            "max_batch_size": max_batch_size,
+            "batch_size": batch_size,
             "max_cache_len": max_cache_len,
             "device": device,
             "dtype": cache_dtype,
@@ -1812,7 +1812,7 @@ def generate(
                 )
             model_kwargs[cache_name] = self._get_cache(
                 cache_implementation=generation_config.cache_implementation,
-                max_batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
+                batch_size=generation_config.num_beams * generation_config.num_return_sequences * batch_size,
                 max_cache_len=generation_config.max_length,
                 device=device,
                 model_kwargs=model_kwargs,
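The last hunk is why the new `batch_size` docstring mentions beam search: `generate()` sizes the cache as `num_beams * num_return_sequences *` the input batch size, so a cache built by hand for beam search has to use the same arithmetic. A hedged sketch of that sizing (the numbers are illustrative):

```python
# Sizing a manually created static cache for beam search (illustrative numbers).
input_batch_size = 2       # rows in `input_ids`
num_beams = 4              # generation_config.num_beams
num_return_sequences = 1   # generation_config.num_return_sequences

# Mirrors the expression generate() now passes as `batch_size`:
cache_batch_size = num_beams * num_return_sequences * input_batch_size
assert cache_batch_size == 8

# past_key_values = StaticCache(config=model.config, batch_size=cache_batch_size,
#                               max_cache_len=generation_config.max_length,
#                               device=model.device, dtype=model.dtype)
```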

src/transformers/models/gemma2/modeling_gemma2.py

Lines changed: 1 addition & 1 deletion
@@ -818,7 +818,7 @@ def forward(
             batch_size, seq_len, _ = inputs_embeds.shape
             past_key_values = HybridCache(
                 self.config,
-                max_batch_size=batch_size,
+                batch_size=batch_size,
                 max_cache_len=seq_len,
                 device=self.device,
                 dtype=inputs_embeds.dtype,

tests/models/llama/test_modeling_llama.py

Lines changed: 2 additions & 2 deletions
@@ -1040,7 +1040,7 @@ def test_stacked_causal_mask_static_cache(self):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,
@@ -1088,7 +1088,7 @@ def test_partial_stacked_causal_mask_static_cache(self):
         max_cache_len = 16  # note that max_cache_len is greater than the attention_mask.shape[-1]
         past_key_values = StaticCache(
             config=self.model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=max_cache_len,
             device=torch_device,
             dtype=self.model.dtype,

tests/models/phi3/test_modeling_phi3.py

Lines changed: 2 additions & 2 deletions
@@ -47,12 +47,12 @@
 end_of_text_token = 32000

 class Phi3MiniWithStaticCache(torch.nn.Module):
-    def __init__(self, model: Phi3ForCausalLM, max_batch_size: int, max_seq_len: int):
+    def __init__(self, model: Phi3ForCausalLM, batch_size: int, max_seq_len: int):
         super().__init__()
         self.model = model
         self.cache = StaticCache(
             config=model.config,
-            max_batch_size=max_batch_size,
+            batch_size=batch_size,
             max_cache_len=max_seq_len,
             device=self.model.device,
             dtype=self.model.dtype,

tests/quantization/aqlm_integration/test_aqlm.py

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ def decode_one_tokens(model, cur_token, input_pos, cache_position, past_key_valu
         # Setup static KV cache for generation
         past_key_values = StaticCache(
             config=self.quantized_model.config,
-            max_batch_size=1,
+            batch_size=1,
             max_cache_len=seq_length + self.max_new_tokens + 1,
             device=torch_device,
             dtype=self.quantized_model.config._pre_quantization_dtype,
