Skip to content

Commit ddc51c4

Browse files
authored
Support for specifying error types with retry (#1817)
1 parent 940d9fc commit ddc51c4

File tree

4 files changed

+169
-12
lines changed

4 files changed

+169
-12
lines changed

redis/client.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -869,6 +869,7 @@ def __init__(
869869
errors=None,
870870
decode_responses=False,
871871
retry_on_timeout=False,
872+
retry_on_error=[],
872873
ssl=False,
873874
ssl_keyfile=None,
874875
ssl_certfile=None,
@@ -887,8 +888,10 @@ def __init__(
887888
):
888889
"""
889890
Initialize a new Redis client.
890-
To specify a retry policy, first set `retry_on_timeout` to `True`
891-
then set `retry` to a valid `Retry` object
891+
To specify a retry policy for specific errors, first set
892+
`retry_on_error` to a list of the error/s to retry on, then set
893+
`retry` to a valid `Retry` object.
894+
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
892895
"""
893896
if not connection_pool:
894897
if charset is not None:
@@ -905,7 +908,8 @@ def __init__(
905908
)
906909
)
907910
encoding_errors = errors
908-
911+
if retry_on_timeout is True:
912+
retry_on_error.append(TimeoutError)
909913
kwargs = {
910914
"db": db,
911915
"username": username,
@@ -914,7 +918,7 @@ def __init__(
914918
"encoding": encoding,
915919
"encoding_errors": encoding_errors,
916920
"decode_responses": decode_responses,
917-
"retry_on_timeout": retry_on_timeout,
921+
"retry_on_error": retry_on_error,
918922
"retry": copy.deepcopy(retry),
919923
"max_connections": max_connections,
920924
"health_check_interval": health_check_interval,
@@ -1146,11 +1150,14 @@ def _send_command_parse_response(self, conn, command_name, *args, **options):
11461150
def _disconnect_raise(self, conn, error):
11471151
"""
11481152
Close the connection and raise an exception
1149-
if retry_on_timeout is not set or the error
1150-
is not a TimeoutError
1153+
if retry_on_error is not set or the error
1154+
is not one of the specified error types
11511155
"""
11521156
conn.disconnect()
1153-
if not (conn.retry_on_timeout and isinstance(error, TimeoutError)):
1157+
if (
1158+
conn.retry_on_error is None
1159+
or isinstance(error, tuple(conn.retry_on_error)) is False
1160+
):
11541161
raise error
11551162

11561163
# COMMAND EXECUTION AND PROTOCOL PARSING

redis/connection.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,7 @@ def __init__(
513513
socket_keepalive_options=None,
514514
socket_type=0,
515515
retry_on_timeout=False,
516+
retry_on_error=[],
516517
encoding="utf-8",
517518
encoding_errors="strict",
518519
decode_responses=False,
@@ -526,8 +527,10 @@ def __init__(
526527
):
527528
"""
528529
Initialize a new Connection.
529-
To specify a retry policy, first set `retry_on_timeout` to `True`
530-
then set `retry` to a valid `Retry` object
530+
To specify a retry policy for specific errors, first set
531+
`retry_on_error` to a list of the error/s to retry on, then set
532+
`retry` to a valid `Retry` object.
533+
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
531534
"""
532535
self.pid = os.getpid()
533536
self.host = host
@@ -543,11 +546,17 @@ def __init__(
543546
self.socket_type = socket_type
544547
self.retry_on_timeout = retry_on_timeout
545548
if retry_on_timeout:
549+
# Add TimeoutError to the errors list to retry on
550+
retry_on_error.append(TimeoutError)
551+
self.retry_on_error = retry_on_error
552+
if retry_on_error:
546553
if retry is None:
547554
self.retry = Retry(NoBackoff(), 1)
548555
else:
549556
# deep-copy the Retry object as it is mutable
550557
self.retry = copy.deepcopy(retry)
558+
# Update the retry's supported errors with the specified errors
559+
self.retry.update_supported_erros(retry_on_error)
551560
else:
552561
self.retry = Retry(NoBackoff(), 0)
553562
self.health_check_interval = health_check_interval
@@ -969,6 +978,7 @@ def __init__(
969978
encoding_errors="strict",
970979
decode_responses=False,
971980
retry_on_timeout=False,
981+
retry_on_error=[],
972982
parser_class=DefaultParser,
973983
socket_read_size=65536,
974984
health_check_interval=0,
@@ -978,8 +988,10 @@ def __init__(
978988
):
979989
"""
980990
Initialize a new UnixDomainSocketConnection.
981-
To specify a retry policy, first set `retry_on_timeout` to `True`
982-
then set `retry` to a valid `Retry` object
991+
To specify a retry policy for specific errors, first set
992+
`retry_on_error` to a list of the error/s to retry on, then set
993+
`retry` to a valid `Retry` object.
994+
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
983995
"""
984996
self.pid = os.getpid()
985997
self.path = path
@@ -990,11 +1002,17 @@ def __init__(
9901002
self.socket_timeout = socket_timeout
9911003
self.retry_on_timeout = retry_on_timeout
9921004
if retry_on_timeout:
1005+
# Add TimeoutError to the errors list to retry on
1006+
retry_on_error.append(TimeoutError)
1007+
self.retry_on_error = retry_on_error
1008+
if self.retry_on_error:
9931009
if retry is None:
9941010
self.retry = Retry(NoBackoff(), 1)
9951011
else:
9961012
# deep-copy the Retry object as it is mutable
9971013
self.retry = copy.deepcopy(retry)
1014+
# Update the retry's supported errors with the specified errors
1015+
self.retry.update_supported_erros(retry_on_error)
9981016
else:
9991017
self.retry = Retry(NoBackoff(), 0)
10001018
self.health_check_interval = health_check_interval
@@ -1052,6 +1070,7 @@ def to_bool(value):
10521070
"socket_connect_timeout": float,
10531071
"socket_keepalive": to_bool,
10541072
"retry_on_timeout": to_bool,
1073+
"retry_on_error": list,
10551074
"max_connections": int,
10561075
"health_check_interval": int,
10571076
"ssl_check_hostname": to_bool,

redis/retry.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,14 @@ def __init__(
1919
self._retries = retries
2020
self._supported_errors = supported_errors
2121

22+
def update_supported_erros(self, specified_errors: list):
23+
"""
24+
Updates the supported errors with the specified error types
25+
"""
26+
self._supported_errors = tuple(
27+
set(self._supported_errors + tuple(specified_errors))
28+
)
29+
2230
def call_with_retry(self, do, fail):
2331
"""
2432
Execute an operation that might fail and returns its result, or

tests/test_retry.py

Lines changed: 124 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
from unittest.mock import patch
2+
13
import pytest
24

35
from redis.backoff import NoBackoff
6+
from redis.client import Redis
47
from redis.connection import Connection, UnixDomainSocketConnection
5-
from redis.exceptions import ConnectionError
8+
from redis.exceptions import (
9+
BusyLoadingError,
10+
ConnectionError,
11+
ReadOnlyError,
12+
TimeoutError,
13+
)
614
from redis.retry import Retry
715

16+
from .conftest import _get_client
17+
818

919
class BackoffMock:
1020
def __init__(self):
@@ -39,6 +49,37 @@ def test_retry_on_timeout_retry(self, Class, retries):
3949
assert isinstance(c.retry, Retry)
4050
assert c.retry._retries == retries
4151

52+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
53+
def test_retry_on_error(self, Class):
54+
c = Class(retry_on_error=[ReadOnlyError])
55+
assert c.retry_on_error == [ReadOnlyError]
56+
assert isinstance(c.retry, Retry)
57+
assert c.retry._retries == 1
58+
59+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
60+
def test_retry_on_error_empty_value(self, Class):
61+
c = Class(retry_on_error=[])
62+
assert c.retry_on_error == []
63+
assert isinstance(c.retry, Retry)
64+
assert c.retry._retries == 0
65+
66+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
67+
def test_retry_on_error_and_timeout(self, Class):
68+
c = Class(
69+
retry_on_error=[ReadOnlyError, BusyLoadingError], retry_on_timeout=True
70+
)
71+
assert c.retry_on_error == [ReadOnlyError, BusyLoadingError, TimeoutError]
72+
assert isinstance(c.retry, Retry)
73+
assert c.retry._retries == 1
74+
75+
@pytest.mark.parametrize("retries", range(10))
76+
@pytest.mark.parametrize("Class", [Connection, UnixDomainSocketConnection])
77+
def test_retry_on_error_retry(self, Class, retries):
78+
c = Class(retry_on_error=[ReadOnlyError], retry=Retry(NoBackoff(), retries))
79+
assert c.retry_on_error == [ReadOnlyError]
80+
assert isinstance(c.retry, Retry)
81+
assert c.retry._retries == retries
82+
4283

4384
class TestRetry:
4485
"Test that Retry calls backoff and retries the expected number of times"
@@ -65,3 +106,85 @@ def test_retry(self, retries):
65106
assert self.actual_failures == 1 + retries
66107
assert backoff.reset_calls == 1
67108
assert backoff.calls == retries
109+
110+
111+
@pytest.mark.onlynoncluster
112+
class TestRedisClientRetry:
113+
"Test the standalone Redis client behavior with retries"
114+
115+
def test_client_retry_on_error_with_success(self, request):
116+
with patch.object(Redis, "parse_response") as parse_response:
117+
118+
def mock_parse_response(connection, *args, **options):
119+
def ok_response(connection, *args, **options):
120+
return "MOCK_OK"
121+
122+
parse_response.side_effect = ok_response
123+
raise ReadOnlyError()
124+
125+
parse_response.side_effect = mock_parse_response
126+
r = _get_client(Redis, request, retry_on_error=[ReadOnlyError])
127+
assert r.get("foo") == "MOCK_OK"
128+
assert parse_response.call_count == 2
129+
130+
def test_client_retry_on_error_raise(self, request):
131+
with patch.object(Redis, "parse_response") as parse_response:
132+
parse_response.side_effect = BusyLoadingError()
133+
retries = 3
134+
r = _get_client(
135+
Redis,
136+
request,
137+
retry_on_error=[ReadOnlyError, BusyLoadingError],
138+
retry=Retry(NoBackoff(), retries),
139+
)
140+
with pytest.raises(BusyLoadingError):
141+
try:
142+
r.get("foo")
143+
finally:
144+
assert parse_response.call_count == retries + 1
145+
146+
def test_client_retry_on_error_different_error_raised(self, request):
147+
with patch.object(Redis, "parse_response") as parse_response:
148+
parse_response.side_effect = TimeoutError()
149+
retries = 3
150+
r = _get_client(
151+
Redis,
152+
request,
153+
retry_on_error=[ReadOnlyError],
154+
retry=Retry(NoBackoff(), retries),
155+
)
156+
with pytest.raises(TimeoutError):
157+
try:
158+
r.get("foo")
159+
finally:
160+
assert parse_response.call_count == 1
161+
162+
def test_client_retry_on_error_and_timeout(self, request):
163+
with patch.object(Redis, "parse_response") as parse_response:
164+
parse_response.side_effect = TimeoutError()
165+
retries = 3
166+
r = _get_client(
167+
Redis,
168+
request,
169+
retry_on_error=[ReadOnlyError],
170+
retry_on_timeout=True,
171+
retry=Retry(NoBackoff(), retries),
172+
)
173+
with pytest.raises(TimeoutError):
174+
try:
175+
r.get("foo")
176+
finally:
177+
assert parse_response.call_count == retries + 1
178+
179+
def test_client_retry_on_timeout(self, request):
180+
with patch.object(Redis, "parse_response") as parse_response:
181+
parse_response.side_effect = TimeoutError()
182+
retries = 3
183+
r = _get_client(
184+
Redis, request, retry_on_timeout=True, retry=Retry(NoBackoff(), retries)
185+
)
186+
with pytest.raises(TimeoutError):
187+
try:
188+
r.get("foo")
189+
finally:
190+
assert parse_response.call_count == retries + 1

0 commit comments

Comments
 (0)