Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit 01dd90b

Browse files
authored
Add type hints to DictionaryCache and TTLCache. (#9442)
1 parent 7dcf3fd commit 01dd90b

File tree

7 files changed

+96
-67
lines changed

7 files changed

+96
-67
lines changed

changelog.d/9442.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add type hints to the caching module.

synapse/http/federation/well_known_resolver.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,10 @@
7171
logger = logging.getLogger(__name__)
7272

7373

74-
_well_known_cache = TTLCache("well-known")
75-
_had_valid_well_known_cache = TTLCache("had-valid-well-known")
74+
_well_known_cache = TTLCache("well-known") # type: TTLCache[bytes, Optional[bytes]]
75+
_had_valid_well_known_cache = TTLCache(
76+
"had-valid-well-known"
77+
) # type: TTLCache[bytes, bool]
7678

7779

7880
@attr.s(slots=True, frozen=True)
@@ -88,8 +90,8 @@ def __init__(
8890
reactor: IReactorTime,
8991
agent: IAgent,
9092
user_agent: bytes,
91-
well_known_cache: Optional[TTLCache] = None,
92-
had_well_known_cache: Optional[TTLCache] = None,
93+
well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None,
94+
had_well_known_cache: Optional[TTLCache[bytes, bool]] = None,
9395
):
9496
self._reactor = reactor
9597
self._clock = Clock(reactor)

synapse/storage/databases/state/store.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,13 @@ def _get_state_for_group_using_cache(self, cache, group, state_filter):
183183
requests state from the cache, if False we need to query the DB for the
184184
missing state.
185185
"""
186-
is_all, known_absent, state_dict_ids = cache.get(group)
186+
cache_entry = cache.get(group)
187+
state_dict_ids = cache_entry.value
187188

188-
if is_all or state_filter.is_full():
189+
if cache_entry.full or state_filter.is_full():
189190
# Either we have everything or want everything, either way
190191
# `is_all` tells us whether we've gotten everything.
191-
return state_filter.filter_state(state_dict_ids), is_all
192+
return state_filter.filter_state(state_dict_ids), cache_entry.full
192193

193194
# tracks whether any of our requested types are missing from the cache
194195
missing_types = False
@@ -202,7 +203,7 @@ def _get_state_for_group_using_cache(self, cache, group, state_filter):
202203
# There aren't any wild cards, so `concrete_types()` returns the
203204
# complete list of event types we're wanting.
204205
for key in state_filter.concrete_types():
205-
if key not in state_dict_ids and key not in known_absent:
206+
if key not in state_dict_ids and key not in cache_entry.known_absent:
206207
missing_types = True
207208
break
208209

synapse/util/caches/dictionary_cache.py

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,26 +15,38 @@
1515
import enum
1616
import logging
1717
import threading
18-
from collections import namedtuple
19-
from typing import Any
18+
from typing import Any, Dict, Generic, Iterable, Optional, Set, TypeVar
19+
20+
import attr
2021

2122
from synapse.util.caches.lrucache import LruCache
2223

2324
logger = logging.getLogger(__name__)
2425

2526

26-
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
27+
# The type of the cache keys.
28+
KT = TypeVar("KT")
29+
# The type of the dictionary keys.
30+
DKT = TypeVar("DKT")
31+
32+
33+
@attr.s(slots=True)
34+
class DictionaryEntry:
2735
"""Returned when getting an entry from the cache
2836
2937
Attributes:
30-
full (bool): Whether the cache has the full or dict or just some keys.
38+
full: Whether the cache has the full or dict or just some keys.
3139
If not full then not all requested keys will necessarily be present
3240
in `value`
33-
known_absent (set): Keys that were looked up in the dict and were not
41+
known_absent: Keys that were looked up in the dict and were not
3442
there.
35-
value (dict): The full or partial dict value
43+
value: The full or partial dict value
3644
"""
3745

46+
full = attr.ib(type=bool)
47+
known_absent = attr.ib()
48+
value = attr.ib()
49+
3850
def __len__(self):
3951
return len(self.value)
4052

@@ -45,21 +57,21 @@ class _Sentinel(enum.Enum):
4557
sentinel = object()
4658

4759

48-
class DictionaryCache:
60+
class DictionaryCache(Generic[KT, DKT]):
4961
"""Caches key -> dictionary lookups, supporting caching partial dicts, i.e.
5062
fetching a subset of dictionary keys for a particular key.
5163
"""
5264

53-
def __init__(self, name, max_entries=1000):
65+
def __init__(self, name: str, max_entries: int = 1000):
5466
self.cache = LruCache(
5567
max_size=max_entries, cache_name=name, size_callback=len
56-
) # type: LruCache[Any, DictionaryEntry]
68+
) # type: LruCache[KT, DictionaryEntry]
5769

5870
self.name = name
5971
self.sequence = 0
60-
self.thread = None
72+
self.thread = None # type: Optional[threading.Thread]
6173

62-
def check_thread(self):
74+
def check_thread(self) -> None:
6375
expected_thread = self.thread
6476
if expected_thread is None:
6577
self.thread = threading.current_thread()
@@ -69,12 +81,14 @@ def check_thread(self):
6981
"Cache objects can only be accessed from the main thread"
7082
)
7183

72-
def get(self, key, dict_keys=None):
84+
def get(
85+
self, key: KT, dict_keys: Optional[Iterable[DKT]] = None
86+
) -> DictionaryEntry:
7387
"""Fetch an entry out of the cache
7488
7589
Args:
7690
key
77-
dict_key(list): If given a set of keys then return only those keys
91+
dict_key: If given a set of keys then return only those keys
7892
that exist in the cache.
7993
8094
Returns:
@@ -95,27 +109,33 @@ def get(self, key, dict_keys=None):
95109

96110
return DictionaryEntry(False, set(), {})
97111

98-
def invalidate(self, key):
112+
def invalidate(self, key: KT) -> None:
99113
self.check_thread()
100114

101115
# Increment the sequence number so that any SELECT statements that
102116
# raced with the INSERT don't update the cache (SYN-369)
103117
self.sequence += 1
104118
self.cache.pop(key, None)
105119

106-
def invalidate_all(self):
120+
def invalidate_all(self) -> None:
107121
self.check_thread()
108122
self.sequence += 1
109123
self.cache.clear()
110124

111-
def update(self, sequence, key, value, fetched_keys=None):
125+
def update(
126+
self,
127+
sequence: int,
128+
key: KT,
129+
value: Dict[DKT, Any],
130+
fetched_keys: Optional[Set[DKT]] = None,
131+
) -> None:
112132
"""Updates the entry in the cache
113133
114134
Args:
115135
sequence
116-
key (K)
117-
value (dict[X,Y]): The value to update the cache with.
118-
fetched_keys (None|set[X]): All of the dictionary keys which were
136+
key
137+
value: The value to update the cache with.
138+
fetched_keys: All of the dictionary keys which were
119139
fetched from the database.
120140
121141
If None, this is the complete value for key K. Otherwise, it
@@ -131,7 +151,9 @@ def update(self, sequence, key, value, fetched_keys=None):
131151
else:
132152
self._update_or_insert(key, value, fetched_keys)
133153

134-
def _update_or_insert(self, key, value, known_absent):
154+
def _update_or_insert(
155+
self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]
156+
) -> None:
135157
# We pop and reinsert as we need to tell the cache the size may have
136158
# changed
137159

@@ -140,5 +162,5 @@ def _update_or_insert(self, key, value, known_absent):
140162
entry.known_absent.update(known_absent)
141163
self.cache[key] = entry
142164

143-
def _insert(self, key, value, known_absent):
165+
def _insert(self, key: KT, value: Dict[DKT, Any], known_absent: Set[DKT]) -> None:
144166
self.cache[key] = DictionaryEntry(True, known_absent, value)

synapse/util/caches/ttlcache.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import logging
1717
import time
18+
from typing import Any, Callable, Dict, Generic, Tuple, TypeVar, Union
1819

1920
import attr
2021
from sortedcontainers import SortedList
@@ -23,15 +24,19 @@
2324

2425
logger = logging.getLogger(__name__)
2526

26-
SENTINEL = object()
27+
SENTINEL = object() # type: Any
2728

29+
T = TypeVar("T")
30+
KT = TypeVar("KT")
31+
VT = TypeVar("VT")
2832

29-
class TTLCache:
33+
34+
class TTLCache(Generic[KT, VT]):
3035
"""A key/value cache implementation where each entry has its own TTL"""
3136

32-
def __init__(self, cache_name, timer=time.time):
37+
def __init__(self, cache_name: str, timer: Callable[[], float] = time.time):
3338
# map from key to _CacheEntry
34-
self._data = {}
39+
self._data = {} # type: Dict[KT, _CacheEntry]
3540

3641
# the _CacheEntries, sorted by expiry time
3742
self._expiry_list = SortedList() # type: SortedList[_CacheEntry]
@@ -40,26 +45,27 @@ def __init__(self, cache_name, timer=time.time):
4045

4146
self._metrics = register_cache("ttl", cache_name, self, resizable=False)
4247

43-
def set(self, key, value, ttl):
48+
def set(self, key: KT, value: VT, ttl: float) -> None:
4449
"""Add/update an entry in the cache
4550
4651
Args:
4752
key: key for this entry
4853
value: value for this entry
49-
ttl (float): TTL for this entry, in seconds
54+
ttl: TTL for this entry, in seconds
5055
"""
5156
expiry = self._timer() + ttl
5257

5358
self.expire()
5459
e = self._data.pop(key, SENTINEL)
55-
if e != SENTINEL:
60+
if e is not SENTINEL:
61+
assert isinstance(e, _CacheEntry)
5662
self._expiry_list.remove(e)
5763

5864
entry = _CacheEntry(expiry_time=expiry, ttl=ttl, key=key, value=value)
5965
self._data[key] = entry
6066
self._expiry_list.add(entry)
6167

62-
def get(self, key, default=SENTINEL):
68+
def get(self, key: KT, default: T = SENTINEL) -> Union[VT, T]:
6369
"""Get a value from the cache
6470
6571
Args:
@@ -72,23 +78,23 @@ def get(self, key, default=SENTINEL):
7278
"""
7379
self.expire()
7480
e = self._data.get(key, SENTINEL)
75-
if e == SENTINEL:
81+
if e is SENTINEL:
7682
self._metrics.inc_misses()
77-
if default == SENTINEL:
83+
if default is SENTINEL:
7884
raise KeyError(key)
7985
return default
86+
assert isinstance(e, _CacheEntry)
8087
self._metrics.inc_hits()
8188
return e.value
8289

83-
def get_with_expiry(self, key):
90+
def get_with_expiry(self, key: KT) -> Tuple[VT, float, float]:
8491
"""Get a value, and its expiry time, from the cache
8592
8693
Args:
8794
key: key to look up
8895
8996
Returns:
90-
Tuple[Any, float, float]: the value from the cache, the expiry time
91-
and the TTL
97+
A tuple of the value from the cache, the expiry time and the TTL
9298
9399
Raises:
94100
KeyError if the entry is not found
@@ -102,7 +108,7 @@ def get_with_expiry(self, key):
102108
self._metrics.inc_hits()
103109
return e.value, e.expiry_time, e.ttl
104110

105-
def pop(self, key, default=SENTINEL):
111+
def pop(self, key: KT, default: T = SENTINEL) -> Union[VT, T]: # type: ignore
106112
"""Remove a value from the cache
107113
108114
If key is in the cache, remove it and return its value, else return default.
@@ -118,29 +124,30 @@ def pop(self, key, default=SENTINEL):
118124
"""
119125
self.expire()
120126
e = self._data.pop(key, SENTINEL)
121-
if e == SENTINEL:
127+
if e is SENTINEL:
122128
self._metrics.inc_misses()
123-
if default == SENTINEL:
129+
if default is SENTINEL:
124130
raise KeyError(key)
125131
return default
132+
assert isinstance(e, _CacheEntry)
126133
self._expiry_list.remove(e)
127134
self._metrics.inc_hits()
128135
return e.value
129136

130-
def __getitem__(self, key):
137+
def __getitem__(self, key: KT) -> VT:
131138
return self.get(key)
132139

133-
def __delitem__(self, key):
140+
def __delitem__(self, key: KT) -> None:
134141
self.pop(key)
135142

136-
def __contains__(self, key):
143+
def __contains__(self, key: KT) -> bool:
137144
return key in self._data
138145

139-
def __len__(self):
146+
def __len__(self) -> int:
140147
self.expire()
141148
return len(self._data)
142149

143-
def expire(self):
150+
def expire(self) -> None:
144151
"""Run the expiry on the cache. Any entries whose expiry times are due will
145152
be removed
146153
"""
@@ -158,7 +165,7 @@ class _CacheEntry:
158165
"""TTLCache entry"""
159166

160167
# expiry_time is the first attribute, so that entries are sorted by expiry.
161-
expiry_time = attr.ib()
162-
ttl = attr.ib()
168+
expiry_time = attr.ib(type=float)
169+
ttl = attr.ib(type=float)
163170
key = attr.ib()
164171
value = attr.ib()

tests/storage/test_state.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -377,14 +377,11 @@ def test_get_state_for_event(self):
377377
#######################################################
378378
# deliberately remove e2 (room name) from the _state_group_cache
379379

380-
(
381-
is_all,
382-
known_absent,
383-
state_dict_ids,
384-
) = self.state_datastore._state_group_cache.get(group)
380+
cache_entry = self.state_datastore._state_group_cache.get(group)
381+
state_dict_ids = cache_entry.value
385382

386-
self.assertEqual(is_all, True)
387-
self.assertEqual(known_absent, set())
383+
self.assertEqual(cache_entry.full, True)
384+
self.assertEqual(cache_entry.known_absent, set())
388385
self.assertDictEqual(
389386
state_dict_ids,
390387
{
@@ -403,14 +400,11 @@ def test_get_state_for_event(self):
403400
fetched_keys=((e1.type, e1.state_key),),
404401
)
405402

406-
(
407-
is_all,
408-
known_absent,
409-
state_dict_ids,
410-
) = self.state_datastore._state_group_cache.get(group)
403+
cache_entry = self.state_datastore._state_group_cache.get(group)
404+
state_dict_ids = cache_entry.value
411405

412-
self.assertEqual(is_all, False)
413-
self.assertEqual(known_absent, {(e1.type, e1.state_key)})
406+
self.assertEqual(cache_entry.full, False)
407+
self.assertEqual(cache_entry.known_absent, {(e1.type, e1.state_key)})
414408
self.assertDictEqual(state_dict_ids, {(e1.type, e1.state_key): e1.event_id})
415409

416410
############################################

tests/util/test_dict_cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ def test_simple_cache_hit_full(self):
2727
key = "test_simple_cache_hit_full"
2828

2929
v = self.cache.get(key)
30-
self.assertEqual((False, set(), {}), v)
30+
self.assertIs(v.full, False)
31+
self.assertEqual(v.known_absent, set())
32+
self.assertEqual({}, v.value)
3133

3234
seq = self.cache.sequence
3335
test_value = {"test": "test_simple_cache_hit_full"}

0 commit comments

Comments
 (0)