1515
1616import logging
1717import time
18+ from typing import Any , Callable , Dict , Generic , Tuple , TypeVar , Union
1819
1920import attr
2021from sortedcontainers import SortedList
2324
2425logger = logging .getLogger (__name__ )
2526
26- SENTINEL = object ()
27+ SENTINEL = object () # type: Any
2728
29+ T = TypeVar ("T" )
30+ KT = TypeVar ("KT" )
31+ VT = TypeVar ("VT" )
2832
29- class TTLCache :
33+
34+ class TTLCache (Generic [KT , VT ]):
3035 """A key/value cache implementation where each entry has its own TTL"""
3136
32- def __init__ (self , cache_name , timer = time .time ):
37+ def __init__ (self , cache_name : str , timer : Callable [[], float ] = time .time ):
3338 # map from key to _CacheEntry
34- self ._data = {}
39+ self ._data = {} # type: Dict[KT, _CacheEntry]
3540
3641 # the _CacheEntries, sorted by expiry time
3742 self ._expiry_list = SortedList () # type: SortedList[_CacheEntry]
@@ -40,26 +45,27 @@ def __init__(self, cache_name, timer=time.time):
4045
4146 self ._metrics = register_cache ("ttl" , cache_name , self , resizable = False )
4247
43- def set (self , key , value , ttl ) :
48+ def set (self , key : KT , value : VT , ttl : float ) -> None :
4449 """Add/update an entry in the cache
4550
4651 Args:
4752 key: key for this entry
4853 value: value for this entry
49- ttl (float) : TTL for this entry, in seconds
54+ ttl: TTL for this entry, in seconds
5055 """
5156 expiry = self ._timer () + ttl
5257
5358 self .expire ()
5459 e = self ._data .pop (key , SENTINEL )
55- if e != SENTINEL :
60+ if e is not SENTINEL :
61+ assert isinstance (e , _CacheEntry )
5662 self ._expiry_list .remove (e )
5763
5864 entry = _CacheEntry (expiry_time = expiry , ttl = ttl , key = key , value = value )
5965 self ._data [key ] = entry
6066 self ._expiry_list .add (entry )
6167
62- def get (self , key , default = SENTINEL ):
68+ def get (self , key : KT , default : T = SENTINEL ) -> Union [ VT , T ] :
6369 """Get a value from the cache
6470
6571 Args:
@@ -72,23 +78,23 @@ def get(self, key, default=SENTINEL):
7278 """
7379 self .expire ()
7480 e = self ._data .get (key , SENTINEL )
75- if e == SENTINEL :
81+ if e is SENTINEL :
7682 self ._metrics .inc_misses ()
77- if default == SENTINEL :
83+ if default is SENTINEL :
7884 raise KeyError (key )
7985 return default
86+ assert isinstance (e , _CacheEntry )
8087 self ._metrics .inc_hits ()
8188 return e .value
8289
83- def get_with_expiry (self , key ) :
90+ def get_with_expiry (self , key : KT ) -> Tuple [ VT , float , float ] :
8491 """Get a value, and its expiry time, from the cache
8592
8693 Args:
8794 key: key to look up
8895
8996 Returns:
90- Tuple[Any, float, float]: the value from the cache, the expiry time
91- and the TTL
97+ A tuple of the value from the cache, the expiry time and the TTL
9298
9399 Raises:
94100 KeyError if the entry is not found
@@ -102,7 +108,7 @@ def get_with_expiry(self, key):
102108 self ._metrics .inc_hits ()
103109 return e .value , e .expiry_time , e .ttl
104110
105- def pop (self , key , default = SENTINEL ):
111+ def pop (self , key : KT , default : T = SENTINEL ) -> Union [ VT , T ]: # type: ignore
106112 """Remove a value from the cache
107113
108114 If key is in the cache, remove it and return its value, else return default.
@@ -118,29 +124,30 @@ def pop(self, key, default=SENTINEL):
118124 """
119125 self .expire ()
120126 e = self ._data .pop (key , SENTINEL )
121- if e == SENTINEL :
127+ if e is SENTINEL :
122128 self ._metrics .inc_misses ()
123- if default == SENTINEL :
129+ if default is SENTINEL :
124130 raise KeyError (key )
125131 return default
132+ assert isinstance (e , _CacheEntry )
126133 self ._expiry_list .remove (e )
127134 self ._metrics .inc_hits ()
128135 return e .value
129136
130- def __getitem__ (self , key ) :
137+ def __getitem__ (self , key : KT ) -> VT :
131138 return self .get (key )
132139
133- def __delitem__ (self , key ) :
140+ def __delitem__ (self , key : KT ) -> None :
134141 self .pop (key )
135142
136- def __contains__ (self , key ) :
143+ def __contains__ (self , key : KT ) -> bool :
137144 return key in self ._data
138145
139- def __len__ (self ):
146+ def __len__ (self ) -> int :
140147 self .expire ()
141148 return len (self ._data )
142149
143- def expire (self ):
150+ def expire (self ) -> None :
144151 """Run the expiry on the cache. Any entries whose expiry times are due will
145152 be removed
146153 """
@@ -158,7 +165,7 @@ class _CacheEntry:
158165 """TTLCache entry"""
159166
160167 # expiry_time is the first attribute, so that entries are sorted by expiry.
161- expiry_time = attr .ib ()
162- ttl = attr .ib ()
168+ expiry_time = attr .ib (type = float )
169+ ttl = attr .ib (type = float )
163170 key = attr .ib ()
164171 value = attr .ib ()
0 commit comments