Skip to content

Commit a9a9f70

Browse files
dvora-hvladvildanov
authored andcommitted
Client side caching invalidations (standalone) (#3089)
* cache invalidations * isort * deamon thread * remove threads * delete comment * tests * skip if hiredis available * async * review comments * docstring * decode test * fix test * fix decode response test
1 parent d3b854d commit a9a9f70

File tree

10 files changed

+457
-57
lines changed

10 files changed

+457
-57
lines changed

redis/__init__.py

-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from redis import asyncio # noqa
44
from redis.backoff import default_backoff
5-
from redis.cache import _LocalChace
65
from redis.client import Redis, StrictRedis
76
from redis.cluster import RedisCluster
87
from redis.connection import (
@@ -62,7 +61,6 @@ def int_or_str(value):
6261
VERSION = tuple([99, 99, 99])
6362

6463
__all__ = [
65-
"_LocalChace",
6664
"AuthenticationError",
6765
"AuthenticationWrongNumberOfArgsError",
6866
"BlockingConnectionPool",

redis/_parsers/resp3.py

+44-22
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,18 @@
66
from .base import _AsyncRESPBase, _RESPBase
77
from .socket import SERVER_CLOSED_CONNECTION_ERROR
88

9+
_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
10+
911

1012
class _RESP3Parser(_RESPBase):
1113
"""RESP3 protocol implementation"""
1214

1315
def __init__(self, socket_read_size):
1416
super().__init__(socket_read_size)
15-
self.push_handler_func = self.handle_push_response
17+
self.pubsub_push_handler_func = self.handle_pubsub_push_response
18+
self.invalidations_push_handler_func = None
1619

17-
def handle_push_response(self, response):
20+
def handle_pubsub_push_response(self, response):
1821
logger = getLogger("push_response")
1922
logger.info("Push response: " + str(response))
2023
return response
@@ -114,30 +117,40 @@ def _read_response(self, disable_decoding=False, push_request=False):
114117
)
115118
for _ in range(int(response))
116119
]
117-
res = self.push_handler_func(response)
118-
if not push_request:
119-
return self._read_response(
120-
disable_decoding=disable_decoding, push_request=push_request
121-
)
122-
else:
123-
return res
120+
self.handle_push_response(response, disable_decoding, push_request)
124121
else:
125122
raise InvalidResponse(f"Protocol Error: {raw!r}")
126123

127124
if isinstance(response, bytes) and disable_decoding is False:
128125
response = self.encoder.decode(response)
129126
return response
130127

131-
def set_push_handler(self, push_handler_func):
132-
self.push_handler_func = push_handler_func
128+
def handle_push_response(self, response, disable_decoding, push_request):
129+
if response[0] in _INVALIDATION_MESSAGE:
130+
res = self.invalidation_push_handler_func(response)
131+
else:
132+
res = self.pubsub_push_handler_func(response)
133+
if not push_request:
134+
return self._read_response(
135+
disable_decoding=disable_decoding, push_request=push_request
136+
)
137+
else:
138+
return res
139+
140+
def set_pubsub_push_handler(self, pubsub_push_handler_func):
141+
self.pubsub_push_handler_func = pubsub_push_handler_func
142+
143+
def set_invalidation_push_handler(self, invalidations_push_handler_func):
144+
self.invalidation_push_handler_func = invalidations_push_handler_func
133145

134146

135147
class _AsyncRESP3Parser(_AsyncRESPBase):
136148
def __init__(self, socket_read_size):
137149
super().__init__(socket_read_size)
138-
self.push_handler_func = self.handle_push_response
150+
self.pubsub_push_handler_func = self.handle_pubsub_push_response
151+
self.invalidations_push_handler_func = None
139152

140-
def handle_push_response(self, response):
153+
def handle_pubsub_push_response(self, response):
141154
logger = getLogger("push_response")
142155
logger.info("Push response: " + str(response))
143156
return response
@@ -246,19 +259,28 @@ async def _read_response(
246259
)
247260
for _ in range(int(response))
248261
]
249-
res = self.push_handler_func(response)
250-
if not push_request:
251-
return await self._read_response(
252-
disable_decoding=disable_decoding, push_request=push_request
253-
)
254-
else:
255-
return res
262+
await self.handle_push_response(response, disable_decoding, push_request)
256263
else:
257264
raise InvalidResponse(f"Protocol Error: {raw!r}")
258265

259266
if isinstance(response, bytes) and disable_decoding is False:
260267
response = self.encoder.decode(response)
261268
return response
262269

263-
def set_push_handler(self, push_handler_func):
264-
self.push_handler_func = push_handler_func
270+
async def handle_push_response(self, response, disable_decoding, push_request):
271+
if response[0] in _INVALIDATION_MESSAGE:
272+
res = self.invalidation_push_handler_func(response)
273+
else:
274+
res = self.pubsub_push_handler_func(response)
275+
if not push_request:
276+
return await self._read_response(
277+
disable_decoding=disable_decoding, push_request=push_request
278+
)
279+
else:
280+
return res
281+
282+
def set_pubsub_push_handler(self, pubsub_push_handler_func):
283+
self.pubsub_push_handler_func = pubsub_push_handler_func
284+
285+
def set_invalidation_push_handler(self, invalidations_push_handler_func):
286+
self.invalidation_push_handler_func = invalidations_push_handler_func

redis/asyncio/client.py

+113-18
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
)
3838
from redis.asyncio.lock import Lock
3939
from redis.asyncio.retry import Retry
40+
from redis.cache import (
41+
DEFAULT_BLACKLIST,
42+
DEFAULT_EVICTION_POLICY,
43+
DEFAULT_WHITELIST,
44+
_LocalCache,
45+
)
4046
from redis.client import (
4147
EMPTY_RESPONSE,
4248
NEVER_DECODE,
@@ -60,7 +66,7 @@
6066
TimeoutError,
6167
WatchError,
6268
)
63-
from redis.typing import ChannelT, EncodableT, KeyT
69+
from redis.typing import ChannelT, EncodableT, KeysT, KeyT, ResponseT
6470
from redis.utils import (
6571
HIREDIS_AVAILABLE,
6672
_set_info_logger,
@@ -231,6 +237,13 @@ def __init__(
231237
redis_connect_func=None,
232238
credential_provider: Optional[CredentialProvider] = None,
233239
protocol: Optional[int] = 2,
240+
cache_enable: bool = False,
241+
client_cache: Optional[_LocalCache] = None,
242+
cache_max_size: int = 100,
243+
cache_ttl: int = 0,
244+
cache_eviction_policy: str = DEFAULT_EVICTION_POLICY,
245+
cache_blacklist: List[str] = DEFAULT_BLACKLIST,
246+
cache_whitelist: List[str] = DEFAULT_WHITELIST,
234247
):
235248
"""
236249
Initialize a new Redis client.
@@ -336,6 +349,16 @@ def __init__(
336349
# on a set of redis commands
337350
self._single_conn_lock = asyncio.Lock()
338351

352+
self.client_cache = client_cache
353+
if cache_enable:
354+
self.client_cache = _LocalCache(
355+
cache_max_size, cache_ttl, cache_eviction_policy
356+
)
357+
if self.client_cache is not None:
358+
self.cache_blacklist = cache_blacklist
359+
self.cache_whitelist = cache_whitelist
360+
self.client_cache_initialized = False
361+
339362
def __repr__(self):
340363
return f"{self.__class__.__name__}<{self.connection_pool!r}>"
341364

@@ -347,6 +370,10 @@ async def initialize(self: _RedisT) -> _RedisT:
347370
async with self._single_conn_lock:
348371
if self.connection is None:
349372
self.connection = await self.connection_pool.get_connection("_")
373+
if self.client_cache is not None:
374+
self.connection._parser.set_invalidation_push_handler(
375+
self._cache_invalidation_process
376+
)
350377
return self
351378

352379
def set_response_callback(self, command: str, callback: ResponseCallbackT):
@@ -565,6 +592,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
565592
close_connection_pool is None and self.auto_close_connection_pool
566593
):
567594
await self.connection_pool.disconnect()
595+
if self.client_cache:
596+
self.client_cache.flush()
568597

569598
@deprecated_function(version="5.0.1", reason="Use aclose() instead", name="close")
570599
async def close(self, close_connection_pool: Optional[bool] = None) -> None:
@@ -593,29 +622,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
593622
):
594623
raise error
595624

625+
def _cache_invalidation_process(
626+
self, data: List[Union[str, Optional[List[str]]]]
627+
) -> None:
628+
"""
629+
Invalidate (delete) all redis commands associated with a specific key.
630+
`data` is a list of strings, where the first string is the invalidation message
631+
and the second string is the list of keys to invalidate.
632+
(if the list of keys is None, then all keys are invalidated)
633+
"""
634+
if data[1] is not None:
635+
for key in data[1]:
636+
self.client_cache.invalidate(str_if_bytes(key))
637+
else:
638+
self.client_cache.flush()
639+
640+
async def _get_from_local_cache(self, command: str):
641+
"""
642+
If the command is in the local cache, return the response
643+
"""
644+
if (
645+
self.client_cache is None
646+
or command[0] in self.cache_blacklist
647+
or command[0] not in self.cache_whitelist
648+
):
649+
return None
650+
while not self.connection._is_socket_empty():
651+
await self.connection.read_response(push_request=True)
652+
return self.client_cache.get(command)
653+
654+
def _add_to_local_cache(
655+
self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
656+
):
657+
"""
658+
Add the command and response to the local cache if the command
659+
is allowed to be cached
660+
"""
661+
if (
662+
self.client_cache is not None
663+
and (self.cache_blacklist == [] or command[0] not in self.cache_blacklist)
664+
and (self.cache_whitelist == [] or command[0] in self.cache_whitelist)
665+
):
666+
self.client_cache.set(command, response, keys)
667+
668+
def delete_from_local_cache(self, command: str):
669+
"""
670+
Delete the command from the local cache
671+
"""
672+
try:
673+
self.client_cache.delete(command)
674+
except AttributeError:
675+
pass
676+
596677
# COMMAND EXECUTION AND PROTOCOL PARSING
597678
async def execute_command(self, *args, **options):
598679
"""Execute a command and return a parsed response"""
599680
await self.initialize()
600-
options.pop("keys", None) # the keys are used only for client side caching
601-
pool = self.connection_pool
602681
command_name = args[0]
603-
conn = self.connection or await pool.get_connection(command_name, **options)
682+
keys = options.pop("keys", None) # keys are used only for client side caching
683+
response_from_cache = await self._get_from_local_cache(args)
684+
if response_from_cache is not None:
685+
return response_from_cache
686+
else:
687+
pool = self.connection_pool
688+
conn = self.connection or await pool.get_connection(command_name, **options)
604689

605-
if self.single_connection_client:
606-
await self._single_conn_lock.acquire()
607-
try:
608-
return await conn.retry.call_with_retry(
609-
lambda: self._send_command_parse_response(
610-
conn, command_name, *args, **options
611-
),
612-
lambda error: self._disconnect_raise(conn, error),
613-
)
614-
finally:
615690
if self.single_connection_client:
616-
self._single_conn_lock.release()
617-
if not self.connection:
618-
await pool.release(conn)
691+
await self._single_conn_lock.acquire()
692+
try:
693+
if self.client_cache is not None and not self.client_cache_initialized:
694+
await conn.retry.call_with_retry(
695+
lambda: self._send_command_parse_response(
696+
conn, "CLIENT", *("CLIENT", "TRACKING", "ON")
697+
),
698+
lambda error: self._disconnect_raise(conn, error),
699+
)
700+
self.client_cache_initialized = True
701+
response = await conn.retry.call_with_retry(
702+
lambda: self._send_command_parse_response(
703+
conn, command_name, *args, **options
704+
),
705+
lambda error: self._disconnect_raise(conn, error),
706+
)
707+
self._add_to_local_cache(args, response, keys)
708+
return response
709+
finally:
710+
if self.single_connection_client:
711+
self._single_conn_lock.release()
712+
if not self.connection:
713+
await pool.release(conn)
619714

620715
async def parse_response(
621716
self, connection: Connection, command_name: Union[str, bytes], **options
@@ -863,7 +958,7 @@ async def connect(self):
863958
else:
864959
await self.connection.connect()
865960
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
866-
self.connection._parser.set_push_handler(self.push_handler_func)
961+
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
867962

868963
async def _disconnect_raise_connect(self, conn, error):
869964
"""

redis/asyncio/connection.py

+4
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,10 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]
645645
output.append(SYM_EMPTY.join(pieces))
646646
return output
647647

648+
def _is_socket_empty(self):
649+
"""Check if the socket is empty"""
650+
return not self._reader.at_eof()
651+
648652

649653
class Connection(AbstractConnection):
650654
"Manages TCP communication to and from a Redis server"

redis/cache.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ class EvictionPolicy(Enum):
159159
RANDOM = "random"
160160

161161

162-
class _LocalChace:
162+
class _LocalCache:
163163
"""
164164
A caching mechanism for storing redis commands and their responses.
165165
@@ -220,6 +220,7 @@ def get(self, command: str) -> ResponseT:
220220
if command in self.cache:
221221
if self._is_expired(command):
222222
self.delete(command)
223+
return
223224
self._update_access(command)
224225
return self.cache[command]["response"]
225226

@@ -266,28 +267,28 @@ def _update_access(self, command: str):
266267
Args:
267268
command (str): The redis command.
268269
"""
269-
if self.eviction_policy == EvictionPolicy.LRU:
270+
if self.eviction_policy == EvictionPolicy.LRU.value:
270271
self.cache.move_to_end(command)
271-
elif self.eviction_policy == EvictionPolicy.LFU:
272+
elif self.eviction_policy == EvictionPolicy.LFU.value:
272273
self.cache[command]["access_count"] = (
273274
self.cache.get(command, {}).get("access_count", 0) + 1
274275
)
275276
self.cache.move_to_end(command)
276-
elif self.eviction_policy == EvictionPolicy.RANDOM:
277+
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
277278
pass # Random eviction doesn't require updates
278279

279280
def _evict(self):
280281
"""Evict a redis command from the cache based on the eviction policy."""
281282
if self._is_expired(self.commands_ttl_list[0]):
282283
self.delete(self.commands_ttl_list[0])
283-
elif self.eviction_policy == EvictionPolicy.LRU:
284+
elif self.eviction_policy == EvictionPolicy.LRU.value:
284285
self.cache.popitem(last=False)
285-
elif self.eviction_policy == EvictionPolicy.LFU:
286+
elif self.eviction_policy == EvictionPolicy.LFU.value:
286287
min_access_command = min(
287288
self.cache, key=lambda k: self.cache[k].get("access_count", 0)
288289
)
289290
self.cache.pop(min_access_command)
290-
elif self.eviction_policy == EvictionPolicy.RANDOM:
291+
elif self.eviction_policy == EvictionPolicy.RANDOM.value:
291292
random_command = random.choice(list(self.cache.keys()))
292293
self.cache.pop(random_command)
293294

@@ -322,5 +323,6 @@ def invalidate(self, key: KeyT):
322323
"""
323324
if key not in self.key_commands_map:
324325
return
325-
for command in self.key_commands_map[key]:
326+
commands = list(self.key_commands_map[key])
327+
for command in commands:
326328
self.delete(command)

0 commit comments

Comments
 (0)