37
37
)
38
38
from redis .asyncio .lock import Lock
39
39
from redis .asyncio .retry import Retry
40
+ from redis .cache import (
41
+ DEFAULT_BLACKLIST ,
42
+ DEFAULT_EVICTION_POLICY ,
43
+ DEFAULT_WHITELIST ,
44
+ _LocalCache ,
45
+ )
40
46
from redis .client import (
41
47
EMPTY_RESPONSE ,
42
48
NEVER_DECODE ,
60
66
TimeoutError ,
61
67
WatchError ,
62
68
)
63
- from redis .typing import ChannelT , EncodableT , KeyT
69
+ from redis .typing import ChannelT , EncodableT , KeysT , KeyT , ResponseT
64
70
from redis .utils import (
65
71
HIREDIS_AVAILABLE ,
66
72
_set_info_logger ,
@@ -231,6 +237,13 @@ def __init__(
231
237
redis_connect_func = None ,
232
238
credential_provider : Optional [CredentialProvider ] = None ,
233
239
protocol : Optional [int ] = 2 ,
240
+ cache_enable : bool = False ,
241
+ client_cache : Optional [_LocalCache ] = None ,
242
+ cache_max_size : int = 100 ,
243
+ cache_ttl : int = 0 ,
244
+ cache_eviction_policy : str = DEFAULT_EVICTION_POLICY ,
245
+ cache_blacklist : List [str ] = DEFAULT_BLACKLIST ,
246
+ cache_whitelist : List [str ] = DEFAULT_WHITELIST ,
234
247
):
235
248
"""
236
249
Initialize a new Redis client.
@@ -336,6 +349,16 @@ def __init__(
336
349
# on a set of redis commands
337
350
self ._single_conn_lock = asyncio .Lock ()
338
351
352
+ self .client_cache = client_cache
353
+ if cache_enable :
354
+ self .client_cache = _LocalCache (
355
+ cache_max_size , cache_ttl , cache_eviction_policy
356
+ )
357
+ if self .client_cache is not None :
358
+ self .cache_blacklist = cache_blacklist
359
+ self .cache_whitelist = cache_whitelist
360
+ self .client_cache_initialized = False
361
+
339
362
def __repr__ (self ):
340
363
return f"{ self .__class__ .__name__ } <{ self .connection_pool !r} >"
341
364
@@ -347,6 +370,10 @@ async def initialize(self: _RedisT) -> _RedisT:
347
370
async with self ._single_conn_lock :
348
371
if self .connection is None :
349
372
self .connection = await self .connection_pool .get_connection ("_" )
373
+ if self .client_cache is not None :
374
+ self .connection ._parser .set_invalidation_push_handler (
375
+ self ._cache_invalidation_process
376
+ )
350
377
return self
351
378
352
379
def set_response_callback (self , command : str , callback : ResponseCallbackT ):
@@ -565,6 +592,8 @@ async def aclose(self, close_connection_pool: Optional[bool] = None) -> None:
565
592
close_connection_pool is None and self .auto_close_connection_pool
566
593
):
567
594
await self .connection_pool .disconnect ()
595
+ if self .client_cache :
596
+ self .client_cache .flush ()
568
597
569
598
@deprecated_function (version = "5.0.1" , reason = "Use aclose() instead" , name = "close" )
570
599
async def close (self , close_connection_pool : Optional [bool ] = None ) -> None :
@@ -593,29 +622,95 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
593
622
):
594
623
raise error
595
624
625
+ def _cache_invalidation_process (
626
+ self , data : List [Union [str , Optional [List [str ]]]]
627
+ ) -> None :
628
+ """
629
+ Invalidate (delete) all redis commands associated with a specific key.
630
+ `data` is a list of strings, where the first string is the invalidation message
631
+ and the second string is the list of keys to invalidate.
632
+ (if the list of keys is None, then all keys are invalidated)
633
+ """
634
+ if data [1 ] is not None :
635
+ for key in data [1 ]:
636
+ self .client_cache .invalidate (str_if_bytes (key ))
637
+ else :
638
+ self .client_cache .flush ()
639
+
640
+ async def _get_from_local_cache (self , command : str ):
641
+ """
642
+ If the command is in the local cache, return the response
643
+ """
644
+ if (
645
+ self .client_cache is None
646
+ or command [0 ] in self .cache_blacklist
647
+ or command [0 ] not in self .cache_whitelist
648
+ ):
649
+ return None
650
+ while not self .connection ._is_socket_empty ():
651
+ await self .connection .read_response (push_request = True )
652
+ return self .client_cache .get (command )
653
+
654
+ def _add_to_local_cache (
655
+ self , command : Tuple [str ], response : ResponseT , keys : List [KeysT ]
656
+ ):
657
+ """
658
+ Add the command and response to the local cache if the command
659
+ is allowed to be cached
660
+ """
661
+ if (
662
+ self .client_cache is not None
663
+ and (self .cache_blacklist == [] or command [0 ] not in self .cache_blacklist )
664
+ and (self .cache_whitelist == [] or command [0 ] in self .cache_whitelist )
665
+ ):
666
+ self .client_cache .set (command , response , keys )
667
+
668
+ def delete_from_local_cache (self , command : str ):
669
+ """
670
+ Delete the command from the local cache
671
+ """
672
+ try :
673
+ self .client_cache .delete (command )
674
+ except AttributeError :
675
+ pass
676
+
596
677
# COMMAND EXECUTION AND PROTOCOL PARSING
597
678
async def execute_command (self , * args , ** options ):
598
679
"""Execute a command and return a parsed response"""
599
680
await self .initialize ()
600
- options .pop ("keys" , None ) # the keys are used only for client side caching
601
- pool = self .connection_pool
602
681
command_name = args [0 ]
603
- conn = self .connection or await pool .get_connection (command_name , ** options )
682
+ keys = options .pop ("keys" , None ) # keys are used only for client side caching
683
+ response_from_cache = await self ._get_from_local_cache (args )
684
+ if response_from_cache is not None :
685
+ return response_from_cache
686
+ else :
687
+ pool = self .connection_pool
688
+ conn = self .connection or await pool .get_connection (command_name , ** options )
604
689
605
- if self .single_connection_client :
606
- await self ._single_conn_lock .acquire ()
607
- try :
608
- return await conn .retry .call_with_retry (
609
- lambda : self ._send_command_parse_response (
610
- conn , command_name , * args , ** options
611
- ),
612
- lambda error : self ._disconnect_raise (conn , error ),
613
- )
614
- finally :
615
690
if self .single_connection_client :
616
- self ._single_conn_lock .release ()
617
- if not self .connection :
618
- await pool .release (conn )
691
+ await self ._single_conn_lock .acquire ()
692
+ try :
693
+ if self .client_cache is not None and not self .client_cache_initialized :
694
+ await conn .retry .call_with_retry (
695
+ lambda : self ._send_command_parse_response (
696
+ conn , "CLIENT" , * ("CLIENT" , "TRACKING" , "ON" )
697
+ ),
698
+ lambda error : self ._disconnect_raise (conn , error ),
699
+ )
700
+ self .client_cache_initialized = True
701
+ response = await conn .retry .call_with_retry (
702
+ lambda : self ._send_command_parse_response (
703
+ conn , command_name , * args , ** options
704
+ ),
705
+ lambda error : self ._disconnect_raise (conn , error ),
706
+ )
707
+ self ._add_to_local_cache (args , response , keys )
708
+ return response
709
+ finally :
710
+ if self .single_connection_client :
711
+ self ._single_conn_lock .release ()
712
+ if not self .connection :
713
+ await pool .release (conn )
619
714
620
715
async def parse_response (
621
716
self , connection : Connection , command_name : Union [str , bytes ], ** options
@@ -863,7 +958,7 @@ async def connect(self):
863
958
else :
864
959
await self .connection .connect ()
865
960
if self .push_handler_func is not None and not HIREDIS_AVAILABLE :
866
- self .connection ._parser .set_push_handler (self .push_handler_func )
961
+ self .connection ._parser .set_pubsub_push_handler (self .push_handler_func )
867
962
868
963
async def _disconnect_raise_connect (self , conn , error ):
869
964
"""
0 commit comments