22
33# pyre-strict
44
5- from typing import Optional , Protocol , Tuple , Type
5+ from typing import Any , Dict , Optional , Protocol , Tuple , Type
66
77import torch
88
9+ from packaging .version import Version
10+ from torch import nn
11+
912
1013class CacheLike (Protocol ):
1114 """Protocol for cache-like objects."""
@@ -21,12 +24,96 @@ def from_legacy_cache(
2124 ) -> "DynamicCacheLike" : ...
2225
2326
# Availability flag for the optional `transformers` dependency, plus optional
# re-exports of its cache classes. Both re-exports are `None` whenever the
# classes cannot be imported.
transformers_installed: bool
Cache: Optional[Type[CacheLike]]
DynamicCache: Optional[Type[DynamicCacheLike]]

try:
    # pyre-ignore[21]: Could not find a module corresponding to import `transformers`.
    import transformers  # noqa: F401
except ImportError:
    transformers_installed = False
else:
    transformers_installed = True

# Start from the "unavailable" state; it is only overwritten once the cache
# classes have actually been imported.
Cache = DynamicCache = None
if transformers_installed:
    try:
        # pyre-ignore[21]: Could not find a module corresponding to import
        # `transformers.cache_utils`.
        from transformers.cache_utils import (  # noqa: F401
            Cache as _Cache,
            DynamicCache as _DynamicCache,
        )
    except ImportError:
        # `transformers` is present but the cache classes could not be
        # imported from it: keep both names as None.
        pass
    else:
        Cache = _Cache
        # pyre-ignore[9]: Incompatible variable type: DynamicCache is declared to have
        # type `Optional[Type[DynamicCacheLike]]` but is used as type
        # `Type[_DynamicCache]`
        DynamicCache = _DynamicCache
57+
# Version thresholds for GenerationMixin._update_model_kwargs_for_generation:
#   v4.39.0: "cache_position" introduced (only needed for models that support
#            the cache class)
#   v4.41.0: "use_cache" introduced (optional, default is True)
#   v4.43.0: "cache_position" becomes mandatory ("use_cache" is still optional,
#            default True)
_cache_position_version = Version("4.39.0")
_use_cache_version = Version("4.41.0")
_mandated_cache_version = Version("4.43.0")

# Parsed version of the installed `transformers` package, or None when the
# package is not installed.
_transformers_version: Optional[Version] = (
    Version(transformers.__version__) if transformers_installed else None
)
71+
72+
def update_model_kwargs(
    model_kwargs: Dict[str, Any],
    model: nn.Module,
    input_ids: torch.Tensor,
    caching: bool,
) -> None:
    """Set the cache-related entries of ``model_kwargs`` in place.

    Nothing is changed when ``supports_caching(model)`` is false. Otherwise,
    depending on the installed ``transformers`` version:

    * ``"cache_position"`` is set only when ``caching`` is true and the version
      is at least ``_cache_position_version``. Its value is
      ``torch.arange(input_ids.shape[1])`` as ``int64`` on ``input_ids.device``.
    * ``"use_cache"`` is set to the truth value of ``caching`` when the version
      is at least ``_use_cache_version``.

    Args:
        model_kwargs: Keyword-argument dictionary to update; mutated in place.
        model: Model whose caching support decides whether anything is done.
        input_ids: Tensor whose dimension 1 gives the number of cache
            positions (presumably the sequence length of a (batch, seq)
            tensor; confirm against callers).
        caching: Whether caching should be enabled or disabled.
    """
    if not supports_caching(model):
        return
    assert _transformers_version is not None
    installed = _transformers_version
    enable = bool(caching)

    # "cache_position" is only ever written when caching is being enabled.
    if enable and installed >= _cache_position_version:
        model_kwargs["cache_position"] = torch.arange(
            input_ids.shape[1], dtype=torch.int64, device=input_ids.device
        )

    # "use_cache" mirrors the requested mode on versions that accept the key.
    if installed >= _use_cache_version:
        model_kwargs["use_cache"] = enable
97+
3098
31- Cache : Optional [Type [CacheLike ]] = _Cache
32- DynamicCache : Optional [Type [DynamicCacheLike ]] = _DynamicCache
def supports_caching(model: nn.Module) -> bool:
    """Tell whether the additional caching logic applies to ``model``.

    Returns ``False`` when ``transformers`` is not installed, when
    ``GenerationMixin`` cannot be imported, or when ``model`` is not a
    ``GenerationMixin`` instance. Otherwise returns ``True`` if the installed
    version is at least ``_mandated_cache_version``, and falls back to the
    model's ``_supports_cache_class`` attribute (default ``False``) on older
    versions.
    """
    if not transformers_installed:
        # Without transformers this cannot be a transformers model.
        return False

    # Cache may be optional or unsupported depending on model/version.
    try:
        # pyre-ignore[21]: Could not find a module corresponding to import
        # `transformers.generation.utils`.
        from transformers.generation.utils import GenerationMixin
    except ImportError:
        return False

    if not isinstance(model, GenerationMixin):
        # Additional caching logic is only supported for GenerationMixin models.
        return False

    assert _transformers_version is not None
    cache_is_mandatory = _transformers_version >= _mandated_cache_version
    # Older versions advertise support through `_supports_cache_class`.
    return (
        True
        if cache_is_mandatory
        else getattr(model, "_supports_cache_class", False)
    )
0 commit comments