Skip to content

Commit 1776679

Browse files
committed
Added support for async API
1 parent fa9bc3c commit 1776679

File tree

6 files changed

+75
-12
lines changed

6 files changed

+75
-12
lines changed

dev_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ uvloop
1616
vulture>=2.3.0
1717
wheel>=0.30.0
1818
numpy>=1.24.0
19+
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main

redis/asyncio/connection.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import socket
66
import ssl
77
import sys
8+
import threading
89
import warnings
910
import weakref
1011
from abc import abstractmethod
@@ -27,6 +28,7 @@
2728
)
2829
from urllib.parse import ParseResult, parse_qs, unquote, urlparse
2930

31+
from ..event import EventDispatcher, AsyncBeforeCommandExecutionEvent
3032
from ..utils import format_error_message
3133

3234
# the functionality is available in 3.11.x but has a major issue before
@@ -39,7 +41,7 @@
3941
from redis.asyncio.retry import Retry
4042
from redis.backoff import NoBackoff
4143
from redis.connection import DEFAULT_RESP_VERSION
42-
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider
44+
from redis.credentials import CredentialProvider, UsernamePasswordCredentialProvider, StreamingCredentialProvider
4345
from redis.exceptions import (
4446
AuthenticationError,
4547
AuthenticationWrongNumberOfArgsError,
@@ -148,6 +150,7 @@ def __init__(
148150
encoder_class: Type[Encoder] = Encoder,
149151
credential_provider: Optional[CredentialProvider] = None,
150152
protocol: Optional[int] = 2,
153+
event_dispatcher: Optional[EventDispatcher] = EventDispatcher()
151154
):
152155
if (username or password) and credential_provider is not None:
153156
raise DataError(
@@ -195,6 +198,9 @@ def __init__(
195198
self.set_parser(parser_class)
196199
self._connect_callbacks: List[weakref.WeakMethod[ConnectCallbackT]] = []
197200
self._buffer_cutoff = 6000
201+
self._event_dispatcher = event_dispatcher
202+
self._init_auth_args = None
203+
198204
try:
199205
p = int(protocol)
200206
except TypeError:
@@ -333,7 +339,9 @@ async def on_connect(self) -> None:
333339
self.credential_provider
334340
or UsernamePasswordCredentialProvider(self.username, self.password)
335341
)
336-
auth_args = cred_provider.get_credentials()
342+
auth_args = await cred_provider.get_credentials_async()
343+
self._init_auth_args = hash(auth_args)
344+
337345
# if resp version is specified and we have auth args,
338346
# we need to send them via HELLO
339347
if auth_args and self.protocol not in [2, "2"]:
@@ -496,6 +504,10 @@ async def send_packed_command(
496504

497505
async def send_command(self, *args: Any, **kwargs: Any) -> None:
498506
"""Pack and send a command to the Redis server"""
507+
if isinstance(self.credential_provider, StreamingCredentialProvider):
508+
await self._event_dispatcher.dispatch_async(
509+
AsyncBeforeCommandExecutionEvent(args, self._init_auth_args, self, self.credential_provider)
510+
)
499511
await self.send_packed_command(
500512
self.pack_command(*args), check_health=kwargs.get("check_health", True)
501513
)
@@ -1033,6 +1045,7 @@ def __init__(
10331045
self._available_connections: List[AbstractConnection] = []
10341046
self._in_use_connections: Set[AbstractConnection] = set()
10351047
self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder)
1048+
self._lock = threading.Lock()
10361049

10371050
def __repr__(self):
10381051
return (

redis/client.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -314,20 +314,22 @@ def __init__(
314314
"cache_config": cache_config,
315315
}
316316
)
317-
connection_pool = ConnectionPool(**kwargs, event_dispatcher=event_dispatcher)
317+
connection_pool = ConnectionPool(**kwargs)
318318
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
319319
[connection_pool],
320+
ClientType.SYNC,
320321
credential_provider
321322
))
322323
self.auto_close_connection_pool = True
323324
else:
324325
self.auto_close_connection_pool = False
326+
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
327+
[connection_pool],
328+
ClientType.SYNC,
329+
credential_provider
330+
))
325331

326332
self.connection_pool = connection_pool
327-
event_dispatcher.dispatch(AfterPooledConnectionsInstantiationEvent(
328-
[connection_pool],
329-
credential_provider
330-
))
331333

332334
if (cache_config or cache) and self.connection_pool.get_protocol() not in [
333335
3,

redis/connection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def __init__(
230230
credential_provider: Optional[CredentialProvider] = None,
231231
protocol: Optional[int] = 2,
232232
command_packer: Optional[Callable[[], None]] = None,
233-
event_dispatcher: Optional[EventDispatcherInterface] = EventDispatcher()
233+
event_dispatcher: Optional[EventDispatcher] = EventDispatcher()
234234
):
235235
"""
236236
Initialize a new Connection.

redis/event.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from abc import ABC, abstractmethod
2+
from enum import Enum
23
from typing import List, Union, Optional
34

45
from redis.credentials import StreamingCredentialProvider, CredentialProvider
@@ -13,6 +14,15 @@ def listen(self, event: object):
1314
pass
1415

1516

17+
class AsyncEventListenerInterface(ABC):
18+
"""
19+
Represents an async listener for given event object.
20+
"""
21+
@abstractmethod
22+
async def listen(self, event: object):
23+
pass
24+
25+
1626
class EventDispatcherInterface(ABC):
1727
"""
1828
Represents a dispatcher that dispatches events to listeners associated with given event.
@@ -21,6 +31,10 @@ class EventDispatcherInterface(ABC):
2131
def dispatch(self, event: object):
2232
pass
2333

34+
@abstractmethod
35+
async def dispatch_async(self, event: object):
36+
pass
37+
2438

2539
class EventDispatcher(EventDispatcherInterface):
2640
# TODO: Make dispatcher to accept external mappings.
@@ -34,7 +48,10 @@ def __init__(self):
3448
],
3549
AfterPooledConnectionsInstantiationEvent: [
3650
RegisterReAuthForPooledConnections()
37-
]
51+
],
52+
AsyncBeforeCommandExecutionEvent: [
53+
AsyncReAuthBeforeCommandExecutionListener(),
54+
],
3855
}
3956

4057
def dispatch(self, event: object):
@@ -43,6 +60,12 @@ def dispatch(self, event: object):
4360
for listener in listeners:
4461
listener.listen(event)
4562

63+
async def dispatch_async(self, event: object):
64+
listeners = self._event_listeners_mapping.get(type(event))
65+
66+
for listener in listeners:
67+
await listener.listen(event)
68+
4669

4770
class BeforeCommandExecutionEvent:
4871
"""
@@ -71,6 +94,10 @@ def credential_provider(self) -> StreamingCredentialProvider:
7194
return self._credential_provider
7295

7396

97+
class AsyncBeforeCommandExecutionEvent(BeforeCommandExecutionEvent):
98+
pass
99+
100+
74101
class AfterPooledConnectionsInstantiationEvent:
75102
"""
76103
Event that will be fired after pooled connection instances was created.
@@ -111,14 +138,33 @@ def listen(self, event: BeforeCommandExecutionEvent):
111138
event.connection.read_response()
112139

113140

114-
class RegisterReAuthForPooledConnections(EventListenerInterface):
141+
class AsyncReAuthBeforeCommandExecutionListener(AsyncEventListenerInterface):
142+
"""
143+
Async listener that performs re-authentication (if needed) for StreamingCredentialProviders before command execution
144+
"""
115145
def __init__(self):
116-
self._event = None
146+
self._current_cred = None
147+
148+
async def listen(self, event: AsyncBeforeCommandExecutionEvent):
149+
if self._current_cred is None:
150+
self._current_cred = event.initial_cred
151+
152+
credentials = await event.credential_provider.get_credentials_async()
117153

154+
if hash(credentials) != self._current_cred:
155+
self._current_cred = hash(credentials)
156+
await event.connection.send_command('AUTH', credentials[0], credentials[1])
157+
await event.connection.read_response()
158+
159+
160+
class RegisterReAuthForPooledConnections(EventListenerInterface):
118161
"""
119162
Listener that registers a re-authentication callback for pooled connections.
120163
Required by :class:`StreamingCredentialProvider`.
121164
"""
165+
def __init__(self):
166+
self._event = None
167+
122168
def listen(self, event: AfterPooledConnectionsInstantiationEvent):
123169
if isinstance(event.credential_provider, StreamingCredentialProvider):
124170
self._event = event

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1-
async-timeout>=4.0.3
1+
async-timeout>=4.0.3
2+
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main

0 commit comments

Comments
 (0)