@@ -977,13 +977,14 @@ class StaticCache(Cache):
    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used. If you set the batch size manually, take the number of beams into account when running beam search.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
-        device (`torch.device`):
+        device (`torch.device` or `str`):
            The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:
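The beam-search caveat above is simple arithmetic worth making concrete: `generate` expands every input row by the number of beams, so a manually created cache must be sized for the expanded batch. A minimal sketch (`cache_batch_size` is an illustrative helper, not part of this diff):

```python
# Illustrative helper: size a static cache for beam search.
def cache_batch_size(num_inputs: int, num_beams: int = 1) -> int:
    # Each input row is expanded into `num_beams` candidate rows during
    # beam search, and the cache holds key/value states for all of them.
    return num_inputs * num_beams

assert cache_batch_size(num_inputs=2, num_beams=4) == 8  # 2 prompts, 4 beams each
```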
@@ -999,22 +1000,37 @@ class StaticCache(Cache):
    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
-    >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+    >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
    ```
    """

-    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: torch.device = None,
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
        super().__init__()
-        self.max_batch_size = max_batch_size
+        if max_batch_size is not None:
+            logger.warning_once(
+                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.46. Use the more precisely named 'batch_size' argument instead."
+            )
+
+        self.batch_size = batch_size or max_batch_size
        self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

-        self.dtype = dtype if dtype is not None else torch.float32
+        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads
            if getattr(config, "num_key_value_heads", None) is None
@@ -1024,7 +1040,7 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        # Note: There will be a significant perf decrease if switching to use 5D tensors instead.
-        cache_shape = (max_batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
+        cache_shape = (self.batch_size, self.num_key_value_heads, self.max_cache_len, self.head_dim)
        for idx in range(config.num_hidden_layers):
            new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
            new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=device)
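The backward-compatibility pattern introduced here — accept the old keyword, warn once, and fall back to it when the new name is absent — can be sketched in isolation. A minimal sketch, substituting plain `logging` for the `transformers` logger's `warning_once`; the class is hypothetical:

```python
import logging
from typing import Optional

logger = logging.getLogger(__name__)


class ExampleCache:  # hypothetical stand-in for the caches touched in this diff
    def __init__(self, batch_size: int = None, max_batch_size: Optional[int] = None):
        if max_batch_size is not None:
            # Old call sites keep working, but are nudged toward the new name.
            logger.warning(
                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated. "
                "Use 'batch_size' instead."
            )
        # `or` falls back to the deprecated value when `batch_size` is None.
        self.batch_size = batch_size or max_batch_size


assert ExampleCache(max_batch_size=4).batch_size == ExampleCache(batch_size=4).batch_size == 4
```

One subtlety of `batch_size or max_batch_size`: a caller passing `batch_size=0` would silently fall through to the deprecated value, which is harmless here since a zero batch is never valid.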
@@ -1130,13 +1146,14 @@ class SlidingWindowCache(StaticCache):
    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
-        device (`torch.device`):
+        device (`torch.device` or `str`):
            The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:
@@ -1152,13 +1169,22 @@ class SlidingWindowCache(StaticCache):
    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
-    >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+    >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
    ```
    """

-    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device, dtype=None) -> None:
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: torch.device = None,
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
        super().__init__()
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
@@ -1168,7 +1194,12 @@ def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len:
            )
        max_cache_len = min(config.sliding_window, max_cache_len)
        super().__init__(
-            config=config, max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype
+            config=config,
+            batch_size=batch_size,
+            max_cache_len=max_cache_len,
+            device=device,
+            dtype=dtype,
+            max_batch_size=max_batch_size,
        )

    def update(
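The `min` clamp above is the behavioral detail that distinguishes this cache: a sliding-window model never attends past its window, so allocating beyond it would waste memory. A sketch of the arithmetic with assumed values:

```python
# Illustrative: effective allocation length for a sliding-window cache.
sliding_window = 4096            # stand-in for config.sliding_window
requested_max_cache_len = 16384  # what the caller asked for

# SlidingWindowCache clamps before delegating to StaticCache, since keys and
# values older than the window can never be attended to again.
effective_len = min(sliding_window, requested_max_cache_len)
assert effective_len == 4096
```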
@@ -1407,13 +1438,14 @@ class HybridCache(Cache):
    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
-        device (`torch.device`, *optional*, defaults to `"cpu"`):
+        device (`torch.device` or `str`, *optional*, defaults to `"cpu"`):
            The device on which the cache should be initialized. Should be the same as the layer.
-        dtype (*optional*, defaults to `torch.float32`):
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.

    Example:
@@ -1429,28 +1461,42 @@ class HybridCache(Cache):
    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
-    >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
+    >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv_length = outputs.past_key_values # access cache filled with key/values from generation
    ```
    """

-    def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, device="cpu", dtype=None) -> None:
+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        batch_size: int = None,
+        max_cache_len: int = None,
+        device: Union[torch.device, str] = "cpu",
+        dtype: torch.dtype = torch.float32,
+        max_batch_size: Optional[int] = None,
+    ) -> None:
        super().__init__()
+        if max_batch_size is not None:
+            logger.warning_once(
+                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.46. Use the more precisely named 'batch_size' argument instead."
+            )
        if not hasattr(config, "sliding_window") or config.sliding_window is None:
            raise ValueError(
                "Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
                "sliding window attention, please check if there is a `sliding_window` field in the model "
                "config and it's not set to None."
            )
        self.max_cache_len = max_cache_len
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size or max_batch_size
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        self.head_dim = (
            config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
        )

-        self.dtype = dtype if dtype is not None else torch.float32
+        self.dtype = dtype
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )
@@ -1459,9 +1505,9 @@ def __init__(self, config: PretrainedConfig, max_batch_size, max_cache_len, devi
        )
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
-        global_cache_shape = (max_batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
+        global_cache_shape = (self.batch_size, self.num_key_value_heads, max_cache_len, self.head_dim)
        sliding_cache_shape = (
-            max_batch_size,
+            self.batch_size,
            self.num_key_value_heads,
            min(config.sliding_window, max_cache_len),
            self.head_dim,
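HybridCache keeps two shape tuples because global-attention layers need the full `max_cache_len` while sliding-window layers only ever see the window. A sketch of the two shapes with assumed config values:

```python
# Illustrative: the two per-layer cache shapes, with made-up config values.
batch_size, num_key_value_heads, head_dim = 1, 8, 128
max_cache_len, sliding_window = 8192, 4096

global_cache_shape = (batch_size, num_key_value_heads, max_cache_len, head_dim)
sliding_cache_shape = (
    batch_size,
    num_key_value_heads,
    min(sliding_window, max_cache_len),  # window layers never need more than this
    head_dim,
)
assert global_cache_shape[2] == 8192 and sliding_cache_shape[2] == 4096
```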
@@ -1564,11 +1610,12 @@ class MambaCache:
    Arguments:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
-        max_batch_size (`int`):
-            The maximum batch size with which the model will be used.
-        dtype (*optional*, defaults to `torch.float16`):
+        batch_size (`int`):
+            The batch size with which the model will be used. Note that a new instance must be instantiated if a
+            smaller batch size is used.
+        dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
            The default `dtype` to use when initializing the layer.
-        device (`torch.device`, *optional*):
+        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
@@ -1596,37 +1643,43 @@ class MambaCache:
    >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
-    >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
+    >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> past_kv = outputs.past_key_values
    ```
    """

+    # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
    def __init__(
        self,
        config: PretrainedConfig,
-        max_batch_size: int,
+        batch_size: int = None,
        dtype: torch.dtype = torch.float16,
-        device: Optional[str] = None,
-        **kwargs,
+        device: Optional[Union[torch.device, str]] = None,
+        max_batch_size: Optional[int] = None,
    ):
+        if max_batch_size is not None:
+            logger.warning_once(
+                f"The 'max_batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
+                "v4.46. Use the more precisely named 'batch_size' argument instead."
+            )
        self.dtype = dtype
-        self.max_batch_size = max_batch_size
+        self.batch_size = batch_size or max_batch_size
        self.intermediate_size = config.intermediate_size
        self.ssm_state_size = config.state_size
        self.conv_kernel_size = config.conv_kernel

        self.conv_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
-            self.max_batch_size,
+            self.batch_size,
            self.intermediate_size,
            self.conv_kernel_size,
            device=device,
            dtype=dtype,
        )
        self.ssm_states: torch.Tensor = torch.zeros(
            config.num_hidden_layers,
-            self.max_batch_size,
+            self.batch_size,
            self.intermediate_size,
            self.ssm_state_size,
            device=device,
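Unlike the attention caches above, MambaCache holds per-layer convolution and SSM state tensors rather than key/value pairs, with the renamed `batch_size` in the second dimension. A sketch of the allocation with assumed config values:

```python
import torch

# Illustrative: MambaCache preallocates one conv state and one SSM state
# tensor covering all layers; the sizes below are made up.
num_hidden_layers, batch_size = 2, 1
intermediate_size, conv_kernel_size, ssm_state_size = 16, 4, 8

conv_states = torch.zeros(num_hidden_layers, batch_size, intermediate_size, conv_kernel_size)
ssm_states = torch.zeros(num_hidden_layers, batch_size, intermediate_size, ssm_state_size)
assert conv_states.shape == (2, 1, 16, 4) and ssm_states.shape == (2, 1, 16, 8)
```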