Skip to content

Commit b69fa00

Browse files
committed
Added a replacement for the default cluster node in the event of failure. Handles failovers better.
1 parent fa45fb1 commit b69fa00

File tree

5 files changed

+120
-20
lines changed

5 files changed

+120
-20
lines changed

CHANGES

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
* Fixed "cannot pickle '_thread.lock' object" bug (#2354, #2297)
2929
* Added CredentialsProvider class to support password rotation
3030
* Enable Lock for asyncio cluster mode
31+
* Added a replacement for the default cluster node in the event of failure (#2463)
3132

3233
* 4.1.3 (Feb 8, 2022)
3334
* Fix flushdb and flushall (#1926)

redis/asyncio/cluster.py

+26-9
Original file line numberDiff line numberDiff line change
@@ -516,35 +516,44 @@ def set_response_callback(self, command: str, callback: ResponseCallbackT) -> No
516516

517517
async def _determine_nodes(
518518
self, command: str, *args: Any, node_flag: Optional[str] = None
519-
) -> List["ClusterNode"]:
519+
) -> tuple[list["ClusterNode"], bool]:
520+
"""Determine which nodes the command should be executed on
521+
522+
Returns:
523+
tuple[list[ClusterNode], bool]:
524+
A tuple containing a list of target nodes and a bool indicating
525+
whether the returned node was chosen because it is the default node
526+
"""
520527
if not node_flag:
521528
# get the nodes group for this command if it was predefined
522529
node_flag = self.command_flags.get(command)
523530

524531
if node_flag in self.node_flags:
525532
if node_flag == self.__class__.DEFAULT_NODE:
526533
# return the cluster's default node
527-
return [self.nodes_manager.default_node]
534+
return [self.nodes_manager.default_node], True
528535
if node_flag == self.__class__.PRIMARIES:
529536
# return all primaries
530-
return self.nodes_manager.get_nodes_by_server_type(PRIMARY)
537+
return self.nodes_manager.get_nodes_by_server_type(PRIMARY), False
531538
if node_flag == self.__class__.REPLICAS:
532539
# return all replicas
533-
return self.nodes_manager.get_nodes_by_server_type(REPLICA)
540+
return self.nodes_manager.get_nodes_by_server_type(REPLICA), False
534541
if node_flag == self.__class__.ALL_NODES:
535542
# return all nodes
536-
return list(self.nodes_manager.nodes_cache.values())
543+
return list(self.nodes_manager.nodes_cache.values()), False
537544
if node_flag == self.__class__.RANDOM:
538545
# return a random node
539-
return [random.choice(list(self.nodes_manager.nodes_cache.values()))]
546+
return [
547+
random.choice(list(self.nodes_manager.nodes_cache.values()))
548+
], False
540549

541550
# get the node that holds the key's slot
542551
return [
543552
self.nodes_manager.get_node_from_slot(
544553
await self._determine_slot(command, *args),
545554
self.read_from_replicas and command in READ_COMMANDS,
546555
)
547-
]
556+
], False
548557

549558
async def _determine_slot(self, command: str, *args: Any) -> int:
550559
if self.command_flags.get(command) == SLOT_ID:
@@ -641,6 +650,7 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
641650
command = args[0]
642651
target_nodes = []
643652
target_nodes_specified = False
653+
is_default_node = False
644654
retry_attempts = self.cluster_error_retry_attempts
645655

646656
passed_targets = kwargs.pop("target_nodes", None)
@@ -654,10 +664,13 @@ async def execute_command(self, *args: EncodableT, **kwargs: Any) -> Any:
654664
for _ in range(execute_attempts):
655665
if self._initialize:
656666
await self.initialize()
667+
if is_default_node:
668+
# Replace the default cluster node
669+
self.replace_default_node()
657670
try:
658671
if not target_nodes_specified:
659672
# Determine the nodes to execute the command on
660-
target_nodes = await self._determine_nodes(
673+
target_nodes, is_default_node = await self._determine_nodes(
661674
*args, node_flag=passed_targets
662675
)
663676
if not target_nodes:
@@ -1436,12 +1449,13 @@ async def _execute(
14361449
]
14371450

14381451
nodes = {}
1452+
is_default_node = False
14391453
for cmd in todo:
14401454
passed_targets = cmd.kwargs.pop("target_nodes", None)
14411455
if passed_targets and not client._is_node_flag(passed_targets):
14421456
target_nodes = client._parse_target_nodes(passed_targets)
14431457
else:
1444-
target_nodes = await client._determine_nodes(
1458+
target_nodes, is_default_node = await client._determine_nodes(
14451459
*cmd.args, node_flag=passed_targets
14461460
)
14471461
if not target_nodes:
@@ -1487,6 +1501,9 @@ async def _execute(
14871501
result.args = (msg,) + result.args[1:]
14881502
raise result
14891503

1504+
if is_default_node:
1505+
self.replace_default_node()
1506+
14901507
return [cmd.result for cmd in stack]
14911508

14921509
def _split_command_across_slots(

redis/cluster.py

+50-11
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,30 @@ class AbstractRedisCluster:
379379

380380
ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, ClusterDownError)
381381

382+
def replace_default_node(self, target_node: "ClusterNode" = None) -> None:
383+
"""Replace the default cluster node.
384+
A random cluster node will be chosen if target_node isn't passed, and primaries
385+
will be prioritized. The default node will not be changed if there are no other
386+
nodes in the cluster.
387+
388+
Args:
389+
target_node (ClusterNode, optional): Target node to replace the default
390+
node. Defaults to None.
391+
"""
392+
if target_node:
393+
self.nodes_manager.default_node = target_node
394+
else:
395+
curr_node = self.get_default_node()
396+
primaries = [node for node in self.get_primaries() if node != curr_node]
397+
if primaries:
398+
# Choose another primary if the cluster has primaries besides the current default node
399+
self.nodes_manager.default_node = random.choice(primaries)
400+
else:
401+
# Otherwise, choose a replica if the cluster has replicas besides the current default node
402+
replicas = [node for node in self.get_replicas() if node != curr_node]
403+
if replicas:
404+
self.nodes_manager.default_node = random.choice(replicas)
405+
382406

383407
class RedisCluster(AbstractRedisCluster, RedisClusterCommands):
384408
@classmethod
@@ -811,7 +835,14 @@ def set_response_callback(self, command, callback):
811835
"""Set a custom Response Callback"""
812836
self.cluster_response_callbacks[command] = callback
813837

814-
def _determine_nodes(self, *args, **kwargs):
838+
def _determine_nodes(self, *args, **kwargs) -> tuple[list["ClusterNode"], bool]:
839+
"""Determine which nodes the command should be executed on
840+
841+
Returns:
842+
tuple[list[ClusterNode], bool]:
843+
A tuple containing a list of target nodes and a bool indicating
844+
whether the returned node was chosen because it is the default node
845+
"""
815846
command = args[0].upper()
816847
if len(args) >= 2 and f"{args[0]} {args[1]}".upper() in self.command_flags:
817848
command = f"{args[0]} {args[1]}".upper()
@@ -825,28 +856,28 @@ def _determine_nodes(self, *args, **kwargs):
825856
command_flag = self.command_flags.get(command)
826857
if command_flag == self.__class__.RANDOM:
827858
# return a random node
828-
return [self.get_random_node()]
859+
return [self.get_random_node()], False
829860
elif command_flag == self.__class__.PRIMARIES:
830861
# return all primaries
831-
return self.get_primaries()
862+
return self.get_primaries(), False
832863
elif command_flag == self.__class__.REPLICAS:
833864
# return all replicas
834-
return self.get_replicas()
865+
return self.get_replicas(), False
835866
elif command_flag == self.__class__.ALL_NODES:
836867
# return all nodes
837-
return self.get_nodes()
868+
return self.get_nodes(), False
838869
elif command_flag == self.__class__.DEFAULT_NODE:
839870
# return the cluster's default node
840-
return [self.nodes_manager.default_node]
871+
return [self.nodes_manager.default_node], True
841872
elif command in self.__class__.SEARCH_COMMANDS[0]:
842-
return [self.nodes_manager.default_node]
873+
return [self.nodes_manager.default_node], True
843874
else:
844875
# get the node that holds the key's slot
845876
slot = self.determine_slot(*args)
846877
node = self.nodes_manager.get_node_from_slot(
847878
slot, self.read_from_replicas and command in READ_COMMANDS
848879
)
849-
return [node]
880+
return [node], False
850881

851882
def _should_reinitialized(self):
852883
# To reinitialize the cluster on every MOVED error,
@@ -990,6 +1021,7 @@ def execute_command(self, *args, **kwargs):
9901021
dict<Any, ClusterNode>
9911022
"""
9921023
target_nodes_specified = False
1024+
is_default_node = False
9931025
target_nodes = None
9941026
passed_targets = kwargs.pop("target_nodes", None)
9951027
if passed_targets is not None and not self._is_nodes_flag(passed_targets):
@@ -1013,7 +1045,7 @@ def execute_command(self, *args, **kwargs):
10131045
res = {}
10141046
if not target_nodes_specified:
10151047
# Determine the nodes to execute the command on
1016-
target_nodes = self._determine_nodes(
1048+
target_nodes, is_default_node = self._determine_nodes(
10171049
*args, **kwargs, nodes_flag=passed_targets
10181050
)
10191051
if not target_nodes:
@@ -1025,6 +1057,9 @@ def execute_command(self, *args, **kwargs):
10251057
# Return the processed result
10261058
return self._process_result(args[0], res, **kwargs)
10271059
except Exception as e:
1060+
if is_default_node:
1061+
# Replace the default cluster node
1062+
self.replace_default_node()
10281063
if retry_attempts > 0 and type(e) in self.__class__.ERRORS_ALLOW_RETRY:
10291064
# The nodes and slots cache were reinitialized.
10301065
# Try again with the new cluster setup.
@@ -1883,7 +1918,7 @@ def _send_cluster_commands(
18831918
# if we have to run through it again, we only retry
18841919
# the commands that failed.
18851920
attempt = sorted(stack, key=lambda x: x.position)
1886-
1921+
is_default_node = False
18871922
# build a list of node objects based on node names we need to
18881923
nodes = {}
18891924

@@ -1900,7 +1935,7 @@ def _send_cluster_commands(
19001935
if passed_targets and not self._is_nodes_flag(passed_targets):
19011936
target_nodes = self._parse_target_nodes(passed_targets)
19021937
else:
1903-
target_nodes = self._determine_nodes(
1938+
target_nodes, is_default_node = self._determine_nodes(
19041939
*c.args, node_flag=passed_targets
19051940
)
19061941
if not target_nodes:
@@ -1926,6 +1961,8 @@ def _send_cluster_commands(
19261961
# Connection retries are being handled in the node's
19271962
# Retry object. Reinitialize the node -> slot table.
19281963
self.nodes_manager.initialize()
1964+
if is_default_node:
1965+
self.replace_default_node()
19291966
raise
19301967
nodes[node_name] = NodeCommands(
19311968
redis_node.parse_response,
@@ -2007,6 +2044,8 @@ def _send_cluster_commands(
20072044
self.reinitialize_counter += 1
20082045
if self._should_reinitialized():
20092046
self.nodes_manager.initialize()
2047+
if is_default_node:
2048+
self.replace_default_node()
20102049
for c in attempt:
20112050
try:
20122051
# send each command individually like we

tests/test_asyncio/test_cluster.py

+20
Original file line numberDiff line numberDiff line change
@@ -788,6 +788,26 @@ async def test_can_run_concurrent_commands(self, request: FixtureRequest) -> Non
788788
)
789789
await rc.close()
790790

791+
def test_replace_cluster_node(self, r: RedisCluster) -> None:
792+
prev_default_node = r.get_default_node()
793+
r.replace_default_node()
794+
assert r.get_default_node() != prev_default_node
795+
r.replace_default_node(prev_default_node)
796+
assert r.get_default_node() == prev_default_node
797+
798+
async def test_default_node_is_replaced_after_exception(self, r):
799+
curr_default_node = r.get_default_node()
800+
# CLUSTER NODES command is being executed on the default node
801+
nodes = await r.cluster_nodes()
802+
assert "myself" in nodes.get(curr_default_node.name).get("flags")
803+
804+
# Mock connection error for the default node
805+
mock_node_resp_exc(curr_default_node, ConnectionError("error"))
806+
# Test that the command succeeds when served by a different node
807+
nodes = await r.cluster_nodes()
808+
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
809+
assert r.get_default_node() != curr_default_node
810+
791811

792812
class TestClusterRedisCommands:
793813
"""

tests/test_cluster.py

+23
Original file line numberDiff line numberDiff line change
@@ -791,6 +791,29 @@ def test_cluster_retry_object(self, r) -> None:
791791
== retry._retries
792792
)
793793

794+
def test_replace_cluster_node(self, r) -> None:
795+
prev_default_node = r.get_default_node()
796+
r.replace_default_node()
797+
assert r.get_default_node() != prev_default_node
798+
r.replace_default_node(prev_default_node)
799+
assert r.get_default_node() == prev_default_node
800+
801+
def test_default_node_is_replaced_after_exception(self, r):
802+
curr_default_node = r.get_default_node()
803+
# CLUSTER NODES command is being executed on the default node
804+
nodes = r.cluster_nodes()
805+
assert "myself" in nodes.get(curr_default_node.name).get("flags")
806+
807+
def raise_connection_error():
808+
raise ConnectionError("error")
809+
810+
# Mock connection error for the default node
811+
mock_node_resp_func(curr_default_node, raise_connection_error)
812+
# Test that the command succeeds when served by a different node
813+
nodes = r.cluster_nodes()
814+
assert "myself" not in nodes.get(curr_default_node.name).get("flags")
815+
assert r.get_default_node() != curr_default_node
816+
794817

795818
@pytest.mark.onlycluster
796819
class TestClusterRedisCommands:

0 commit comments

Comments
 (0)