From 9c2e299e83ae28619c3b54bf338095e2abadb0a3 Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Thu, 28 Oct 2021 10:49:07 +0300
Subject: [PATCH 01/22] Added RedisCluster client to support Redis Cluster mode

---
 README.md                      |  224 +++-
 docker/base/Dockerfile.cluster |    8 +
 docker/base/create_cluster.sh  |   21 +
 docker/cluster/redis.conf      |    3 +
 redis/__init__.py              |    2 +
 redis/client.py                |   40 +-
 redis/cluster.py               | 1883 ++++++++++++++++++++++++++++++++
 redis/commands/__init__.py     |    4 +
 redis/commands/cluster.py      |  801 ++++++++++++++
 redis/commands/core.py         |   96 +-
 redis/commands/parser.py       |  108 ++
 redis/connection.py            |   24 +-
 redis/crc.py                   |   28 +
 redis/exceptions.py            |   64 ++
 redis/utils.py                 |   36 +
 tasks.py                       |    3 +
 tests/conftest.py              |   68 +-
 tests/test_cluster.py          | 1403 ++++++++++++++++++++++++
 tests/test_commands.py         |  163 ++-
 tests/test_connection.py       |    5 +-
 tests/test_connection_pool.py  |   23 +-
 tests/test_encoding.py         |    1 +
 tests/test_json.py             |  427 ++++----
 tests/test_lock.py             |    4 +-
 tests/test_monitor.py          |    3 +-
 tests/test_multiprocessing.py  |   19 +-
 tests/test_pipeline.py         |   17 +-
 tests/test_pubsub.py           |   14 +-
 tests/test_scripting.py        |    5 +-
 tests/test_sentinel.py         |  299 +++--
 tox.ini                        |   17 +
 31 files changed, 5318 insertions(+), 495 deletions(-)
 create mode 100644 docker/base/Dockerfile.cluster
 create mode 100644 docker/base/create_cluster.sh
 create mode 100644 docker/cluster/redis.conf
 create mode 100644 redis/cluster.py
 create mode 100644 redis/commands/cluster.py
 create mode 100644 redis/commands/parser.py
 create mode 100644 redis/crc.py
 create mode 100644 tests/test_cluster.py

diff --git a/README.md b/README.md
index 4bcea7be33..5cd9b53570 100644
--- a/README.md
+++ b/README.md
@@ -942,8 +942,228 @@ C 3
 
 ### Cluster Mode
 
-redis-py does not currently support [Cluster
-Mode](https://redis.io/topics/cluster-tutorial).
+redis-py now supports cluster mode and provides a client for
+[Redis Cluster](https://redis.io/topics/cluster-tutorial).
+
+The cluster client is based on Grokzen's
+[redis-py-cluster](https://github.com/Grokzen/redis-py-cluster), with a lot
+of added and changed functionality.
+
+**Create RedisCluster:**
+
+Connecting redis-py to a Redis Cluster is easy.
+RedisCluster requires at least one node from which it can discover the rest
+of the cluster's nodes, and there are multiple ways of creating a
+RedisCluster instance:
+
+- Use the 'host' and 'port' arguments:
+
+``` pycon
+    >>> from redis.cluster import RedisCluster as Redis
+    >>> rc = Redis(host='localhost', port=6379)
+    >>> print(rc.get_nodes())
+    [[host=127.0.0.1,port=6379,name=127.0.0.1:6379,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6379,db=0>>>], [host=127.0.0.1,port=6378,name=127.0.0.1:6378,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6378,db=0>>>], [host=127.0.0.1,port=6377,name=127.0.0.1:6377,server_type=replica,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6377,db=0>>>]]
+```
+- Use a Redis URL:
+
+``` pycon
+    >>> from redis.cluster import RedisCluster as Redis
+    >>> rc = Redis.from_url("redis://localhost:6379/0")
+```
+
+- Use ClusterNode(s):
+
+``` pycon
+    >>> from redis.cluster import RedisCluster as Redis
+    >>> from redis.cluster import ClusterNode
+    >>> nodes = [ClusterNode('localhost', 6379), ClusterNode('localhost', 6378)]
+    >>> rc = Redis(startup_nodes=nodes)
+```
+
+When a RedisCluster instance is created, it first attempts to establish a
+connection to one of the provided startup nodes. If none of the startup nodes
+are reachable, a 'RedisClusterException' will be thrown.
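+A minimal sketch of guarding against that failure (assuming nothing is
+listening on the example port):
+
+``` pycon
+    >>> from redis.cluster import RedisCluster as Redis
+    >>> from redis.exceptions import RedisClusterException
+    >>> try:
+    ...     rc = Redis(host='localhost', port=7000)
+    ... except RedisClusterException:
+    ...     print("no reachable startup node")
+```
+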
+After a connection to one of the cluster's nodes is established, the
+RedisCluster instance will be initialized with 3 caches:
+a slots cache, which maps each of the 16384 slots to the node/s handling
+them; a nodes cache, which contains ClusterNode objects (name, host, port,
+redis connection) for all of the cluster's nodes; and a commands cache,
+which contains all the commands the server supports, as retrieved from the
+output of the Redis 'COMMAND' command.
+
+A RedisCluster instance can be used directly to execute Redis commands. When
+a command is executed through the cluster instance, the target node(s) will
+be determined internally. When using a key-based command, the target node
+will be the node that holds the key's slot.
+Cluster management commands and other cluster commands have predefined node
+group targets (all-primaries, all-nodes, random-node, all-replicas), which are
+outlined in each command's function documentation.
+For example, the 'KEYS' command will be sent to all primaries and return all
+keys in the cluster, and the 'CLUSTER NODES' command will be sent to a random
+node. Other management commands will require you to pass the target node/s to
+execute the command on.
+
+``` pycon
+    >>> # target-nodes: the node that holds 'foo1's key slot
+    >>> rc.set('foo1', 'bar1')
+    >>> # target-nodes: the node that holds 'foo2's key slot
+    >>> rc.set('foo2', 'bar2')
+    >>> # target-nodes: the node that holds 'foo1's key slot
+    >>> print(rc.get('foo1'))
+    b'bar1'
+    >>> # target-nodes: all-primaries
+    >>> print(rc.keys())
+    [b'foo1', b'foo2']
+    >>> # target-nodes: all-nodes
+    >>> rc.flushall()
+```
+
+**Specifying Target Nodes:**
+
+As mentioned above, some RedisCluster commands will require you to provide
+the target node/s that you want to execute the command on, while in other
+cases, the target node will be determined by the client itself. That being
+said, ALL RedisCluster commands can be executed against a specific node or a
+group of nodes by passing the command kwarg `target_nodes`.
+The best practice is to specify target nodes using the RedisCluster class's
+node flags: PRIMARIES, REPLICAS, ALL_NODES, RANDOM. When a node flag is
+passed along with a command, it will be internally resolved to the relevant
+node/s. If the cluster's node topology changes during the execution of a
+command, the client will be able to resolve the node flag again against the
+new topology and attempt to retry executing the command.
+
+``` pycon
+    >>> from redis.cluster import RedisCluster as Redis
+    >>> # run cluster-meet command on all of the cluster's nodes
+    >>> rc.cluster_meet(Redis.ALL_NODES, '127.0.0.1', 6379)
+    >>> # ping all replicas
+    >>> rc.ping(Redis.REPLICAS)
+    >>> # ping a random node
+    >>> rc.ping(Redis.RANDOM)
+    >>> # ping all nodes in the cluster, default command behavior
+    >>> rc.ping()
+    >>> # execute bgsave in all primaries
+    >>> rc.bgsave(Redis.PRIMARIES)
+```
+
+You can also pass ClusterNodes directly if you want to execute a command on a
+specific node / node group that isn't addressed by the node flags. However,
+if the command execution fails due to cluster topology changes, a retry
+attempt will not be made, since the passed target node/s may no longer be
+valid, and the relevant cluster or connection error will be returned.
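+If you do need retry behavior when targeting explicit nodes, a minimal
+sketch of an application-side retry loop (the three-attempt budget is an
+arbitrary choice for illustration, not a client default):
+
+``` pycon
+    >>> from redis.exceptions import ConnectionError
+    >>> for attempt in range(3):
+    ...     try:
+    ...         node = rc.get_node('localhost', 6379)  # re-resolve the node
+    ...         rc.keys(node)
+    ...         break
+    ...     except ConnectionError:
+    ...         continue
+```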
+
+For example, passing specific nodes or node subsets:
+
+``` pycon
+    >>> node = rc.get_node('localhost', 6379)
+    >>> # Get the keys only for that specific node
+    >>> rc.keys(node)
+    >>> # get Redis info from a subset of primaries
+    >>> subset_primaries = [node for node in rc.get_primaries() if node.port > 6378]
+    >>> rc.info(subset_primaries)
+```
+
+In addition, you can use the RedisCluster instance to obtain the Redis
+instance of a specific node and execute commands on that node directly. The
+Redis client, however, cannot handle cluster failures and retries.
+
+``` pycon
+    >>> cluster_node = rc.get_node(host='localhost', port=6379)
+    >>> print(cluster_node)
+    [host=127.0.0.1,port=6379,name=127.0.0.1:6379,server_type=primary,redis_connection=Redis<ConnectionPool<Connection<host=127.0.0.1,port=6379,db=0>>>]
+    >>> r = cluster_node.redis_connection
+    >>> r.client_list()
+    [{'id': '276', 'addr': '127.0.0.1:64108', 'fd': '16', 'name': '', 'age': '0', 'idle': '0', 'flags': 'N', 'db': '0', 'sub': '0', 'psub': '0', 'multi': '-1', 'qbuf': '26', 'qbuf-free': '32742', 'argv-mem': '10', 'obl': '0', 'oll': '0', 'omem': '0', 'tot-mem': '54298', 'events': 'r', 'cmd': 'client', 'user': 'default'}]
+    >>> # Get the keys only for that specific node
+    >>> r.keys()
+    [b'foo1']
+```
+
+**Multi-key commands:**
+
+Redis supports multi-key commands in Cluster Mode, such as Set type unions
+or intersections, mset and mget, as long as all the keys hash to the same
+slot. With the RedisCluster client, you can use the known functions (e.g.
+mget, mset) to perform an atomic multi-key operation. However, you must
+ensure all keys are mapped to the same slot, otherwise a
+RedisClusterException will be thrown.
+Redis Cluster implements a concept called hash tags that can be used in order
+to force certain keys to be stored in the same hash slot, see
+[Keys hash tag](https://redis.io/topics/cluster-spec#keys-hash-tags).
+You can also use the nonatomic variants of some multi-key operations and
+pass keys that aren't mapped to the same slot. The client will then map the
+keys to the relevant slots, sending the commands to the slots' node owners.
+Non-atomic operations batch the keys according to their hash value, and then
+each batch is sent separately to the slot's owner.
+
+``` pycon
+    # Atomic operations can be used when all keys are mapped to the same slot
+    >>> rc.mset({'{foo}1': 'bar1', '{foo}2': 'bar2'})
+    >>> rc.mget('{foo}1', '{foo}2')
+    [b'bar1', b'bar2']
+    # Non-atomic multi-key operations split the keys into different slots
+    >>> rc.mset_nonatomic({'foo': 'value1', 'bar': 'value2', 'zzz': 'value3'})
+    >>> rc.mget_nonatomic('foo', 'bar', 'zzz')
+    [b'value1', b'value2', b'value3']
+```
+
+**Cluster PubSub:**
+
+When a ClusterPubSub instance is created without specifying a node, a single
+node will be transparently chosen for the pubsub connection on the
+first command execution. The node will be determined by:
+ 1. Hashing the channel name in the request to find its keyslot
+ 2. Selecting a node that handles the keyslot: If read_from_replicas is
+    set to true, a replica can be selected.
+
+*Known limitations with pubsub:*
+
+Pattern subscribe and publish do not work properly because if we hash a
+pattern like fo* we will get a keyslot for that string, but there are
+endless possible channel names matching that pattern, and we can't know
+them in advance. The commands are not blocked, but using them is not
+recommended right now.
+See the [redis-py-cluster documentation](https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html)
+for more.
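+
+One workaround sketch is to pin the pattern subscription to an explicit
+node, so that publishers and subscribers agree on the node out of band
+(assuming 'localhost:6379' is the agreed-upon node):
+
+``` pycon
+    >>> p = rc.pubsub(rc.get_node('localhost', 6379))
+    >>> p.psubscribe('fo*')
+```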
+
+``` pycon
+    >>> p1 = rc.pubsub()
+    # p1 connection will be set to the node that holds 'foo' keyslot
+    >>> p1.subscribe('foo')
+    # p2 connection will be set to node 'localhost:6379'
+    >>> p2 = rc.pubsub(rc.get_node('localhost', 6379))
+```
+
+**Read Only Mode**
+
+By default, Redis Cluster always returns a MOVED redirection response when
+accessing a replica node. You can overcome this limitation and scale read
+commands with READONLY mode.
+
+To enable READONLY mode, pass read_from_replicas=True to the RedisCluster
+constructor. When set to true, read commands will be distributed between the
+primary and its replicas in a Round-Robin manner.
+
+You can also enable READONLY mode at runtime by calling the readonly()
+method, or disable it with readwrite().
+
+``` pycon
+    >>> from redis.cluster import RedisCluster as Redis
+    # Use 'debug' mode to print the node that the command is executed on
+    >>> rc_readonly = Redis(startup_nodes=startup_nodes,
+    ...                     read_from_replicas=True, debug=True)
+    >>> rc_readonly.set('{foo}1', 'bar1')
+    >>> for i in range(0, 4):
+    ...     # Assigns read command to the slot's hosts in a Round-Robin manner
+    ...     rc_readonly.get('{foo}1')
+    # set command would be directed only to the slot's primary node
+    >>> rc_readonly.set('{foo}2', 'bar2')
+    # reset READONLY flag
+    >>> rc_readonly.readwrite()
+    # now the get command would be directed only to the slot's primary node
+    >>> rc_readonly.get('{foo}1')
+```
+
+See the [Redis Cluster tutorial](https://redis.io/topics/cluster-tutorial) and
+[Redis Cluster specification](https://redis.io/topics/cluster-spec)
+to learn more about Redis Cluster.
 
 ### Author
diff --git a/docker/base/Dockerfile.cluster b/docker/base/Dockerfile.cluster
new file mode 100644
index 0000000000..70e8013631
--- /dev/null
+++ b/docker/base/Dockerfile.cluster
@@ -0,0 +1,8 @@
+FROM redis:6.2.6-buster
+
+COPY create_cluster.sh /create_cluster.sh
+RUN chmod +x /create_cluster.sh
+
+EXPOSE 16379 16380 16381 16382 16383 16384
+
+CMD [ "/create_cluster.sh"]
\ No newline at end of file
diff --git a/docker/base/create_cluster.sh b/docker/base/create_cluster.sh
new file mode 100644
index 0000000000..d339f1836a
--- /dev/null
+++ b/docker/base/create_cluster.sh
@@ -0,0 +1,21 @@
+#! /bin/bash
+mkdir -p /nodes
+echo -n > /nodes/nodemap
+for PORT in $(seq 16379 16384); do
+  mkdir -p /nodes/$PORT
+  if [[ -e /redis.conf ]]; then
+    cp /redis.conf /nodes/$PORT/redis.conf
+  else
+    touch /nodes/$PORT/redis.conf
+  fi
+  cat << EOF >> /nodes/$PORT/redis.conf
+port $PORT
+daemonize yes
+logfile /redis.log
+dir /nodes/$PORT
+EOF
+  redis-server /nodes/$PORT/redis.conf
+  echo 127.0.0.1:$PORT >> /nodes/nodemap
+done
+echo yes | redis-cli --cluster create $(seq -f 127.0.0.1:%g 16379 16384) --cluster-replicas 1
+tail -f /redis.log
diff --git a/docker/cluster/redis.conf b/docker/cluster/redis.conf
new file mode 100644
index 0000000000..cc22e16ffe
--- /dev/null
+++ b/docker/cluster/redis.conf
@@ -0,0 +1,3 @@
+# Redis Cluster config file will be shared across all nodes.
+# Don't pass node-unique arguments (e.g. port, dir).
+cluster-enabled yes diff --git a/redis/__init__.py b/redis/__init__.py index 2458b5bc49..47adaa8c83 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -1,4 +1,5 @@ from redis.client import Redis, StrictRedis +from redis.cluster import RedisCluster from redis.connection import ( BlockingConnectionPool, ConnectionPool, @@ -49,6 +50,7 @@ def int_or_str(value): 'PubSubError', 'ReadOnlyError', 'Redis', + 'RedisCluster', 'RedisError', 'ResponseError', 'SSLConnection', diff --git a/redis/client.py b/redis/client.py index 986af7cfba..93979569b1 100755 --- a/redis/client.py +++ b/redis/client.py @@ -460,6 +460,7 @@ def _parse_node_line(line): line_items = line.split(' ') node_id, addr, flags, master_id, ping, pong, epoch, \ connected = line.split(' ')[:8] + addr = addr.split('@')[0] slots = [sl.split('-') for sl in line_items[8:]] node_dict = { 'node_id': node_id, @@ -475,8 +476,13 @@ def _parse_node_line(line): def parse_cluster_nodes(response, **options): - raw_lines = str_if_bytes(response).splitlines() - return dict(_parse_node_line(line) for line in raw_lines) + """ + @see: http://redis.io/commands/cluster-nodes # string + @see: http://redis.io/commands/cluster-replicas # list of string + """ + if isinstance(response, str): + response = response.splitlines() + return dict(_parse_node_line(str_if_bytes(node)) for node in response) def parse_geosearch_generic(response, **options): @@ -515,6 +521,21 @@ def parse_geosearch_generic(response, **options): ] +def parse_command(response, **options): + commands = {} + for command in response: + cmd_dict = {} + cmd_name = str_if_bytes(command[0]) + cmd_dict['name'] = cmd_name + cmd_dict['arity'] = str_if_bytes(command[1]) + cmd_dict['flags'] = [str_if_bytes(flag) for flag in command[2]] + cmd_dict['first_key_pos'] = command[3] + cmd_dict['last_key_pos'] = command[4] + cmd_dict['step_count'] = command[5] + commands[cmd_name] = cmd_dict + return commands + + def parse_pubsub_numsub(response, **options): return list(zip(response[0::2], response[1::2])) @@ -700,8 +721,10 @@ class Redis(RedisModuleCommands, CoreCommands, object): 'CLUSTER SET-CONFIG-EPOCH': bool_ok, 'CLUSTER SETSLOT': bool_ok, 'CLUSTER SLAVES': parse_cluster_nodes, - 'COMMAND': int, + 'CLUSTER REPLICAS': parse_cluster_nodes, + 'COMMAND': parse_command, 'COMMAND COUNT': int, + 'COMMAND GETKEYS': lambda r: list(map(str_if_bytes, r)), 'CONFIG GET': parse_config_get, 'CONFIG RESETSTAT': bool_ok, 'CONFIG SET': bool_ok, @@ -824,7 +847,7 @@ def __init__(self, host='localhost', port=6379, ssl_check_hostname=False, max_connections=None, single_connection_client=False, health_check_interval=0, client_name=None, username=None, - retry=None): + retry=None, redis_connect_func=None): """ Initialize a new Redis client. 
To specify a retry policy, first set `retry_on_timeout` to `True` @@ -852,7 +875,8 @@ def __init__(self, host='localhost', port=6379, 'retry': copy.deepcopy(retry), 'max_connections': max_connections, 'health_check_interval': health_check_interval, - 'client_name': client_name + 'client_name': client_name, + 'redis_connect_func': redis_connect_func } # based on input, setup appropriate connection args if unix_socket_path is not None: @@ -1188,14 +1212,16 @@ class PubSub: HEALTH_CHECK_MESSAGE = 'redis-py-health-check' def __init__(self, connection_pool, shard_hint=None, - ignore_subscribe_messages=False): + ignore_subscribe_messages=False, encoder=None): self.connection_pool = connection_pool self.shard_hint = shard_hint self.ignore_subscribe_messages = ignore_subscribe_messages self.connection = None # we need to know the encoding options for this connection in order # to lookup channel and pattern names for callback handlers. - self.encoder = self.connection_pool.get_encoder() + self.encoder = encoder + if self.encoder is None: + self.encoder = self.connection_pool.get_encoder() if self.encoder.decode_responses: self.health_check_response = ['pong', self.HEALTH_CHECK_MESSAGE] else: diff --git a/redis/cluster.py b/redis/cluster.py new file mode 100644 index 0000000000..e3976dcb03 --- /dev/null +++ b/redis/cluster.py @@ -0,0 +1,1883 @@ +import copy +import random +import socket +import time +import threading +import warnings +import sys + +from collections import OrderedDict +from redis.client import CaseInsensitiveDict, Redis, PubSub +from redis.commands import ( + ClusterCommands, + CommandsParser +) +from redis.connection import DefaultParser, ConnectionPool, Encoder, parse_url +from redis.crc import key_slot, REDIS_CLUSTER_HASH_SLOTS +from redis.exceptions import ( + AskError, + BusyLoadingError, + ClusterCrossSlotError, + ClusterDownError, + ClusterError, + DataError, + MasterDownError, + MovedError, + RedisClusterException, + RedisError, + ResponseError, + SlotNotCoveredError, + TimeoutError, + TryAgainError, +) +from redis.utils import ( + dict_merge, + list_keys_to_dict, + merge_result, + str_if_bytes, + safe_str +) + + +def get_node_name(host, port): + return '{0}:{1}'.format(host, port) + + +def get_connection(redis_node, *args, **options): + return redis_node.connection or redis_node.connection_pool.get_connection( + args[0], **options + ) + + +def parse_pubsub_numsub(command, res, **options): + numsub_d = OrderedDict() + for numsub_tups in res.values(): + for channel, numsubbed in numsub_tups: + try: + numsub_d[channel] += numsubbed + except KeyError: + numsub_d[channel] = numsubbed + + ret_numsub = [ + (channel, numsub) + for channel, numsub in numsub_d.items() + ] + return ret_numsub + + +def parse_cluster_slots(resp, **options): + current_host = options.get('current_host', '') + + def fix_server(*args): + return str_if_bytes(args[0]) or current_host, args[1] + + slots = {} + for slot in resp: + start, end, primary = slot[:3] + replicas = slot[3:] + slots[start, end] = { + 'primary': fix_server(*primary), + 'replicas': [fix_server(*replica) for replica in replicas], + } + + return slots + + +PRIMARY = "primary" +REPLICA = "replica" +SLOT_ID = 'slot-id' + +REDIS_ALLOWED_KEYS = ( + "charset", + "connection_class", + "connection_pool", + "db", + "decode_responses", + "encoding", + "encoding_errors", + "errors", + "host", + "max_connections", + "nodes_flag", + "redis_connect_func", + "password", + "port", + "retry_on_timeout", + "socket_connect_timeout", + "socket_keepalive", + 
"socket_keepalive_options", + "socket_timeout", + "ssl", + "ssl_ca_certs", + "ssl_certfile", + "ssl_cert_reqs", + "ssl_keyfile", + "unix_socket_path", + "username", +) +KWARGS_DISABLED_KEYS = ( + "host", + "port", +) + +# Not complete, but covers the major ones +# https://redis.io/commands +READ_COMMANDS = frozenset([ + "BITCOUNT", + "BITPOS", + "EXISTS", + "GEODIST", + "GEOHASH", + "GEOPOS", + "GEORADIUS", + "GEORADIUSBYMEMBER", + "GET", + "GETBIT", + "GETRANGE", + "HEXISTS", + "HGET", + "HGETALL", + "HKEYS", + "HLEN", + "HMGET", + "HSTRLEN", + "HVALS", + "KEYS", + "LINDEX", + "LLEN", + "LRANGE", + "MGET", + "PTTL", + "RANDOMKEY", + "SCARD", + "SDIFF", + "SINTER", + "SISMEMBER", + "SMEMBERS", + "SRANDMEMBER", + "STRLEN", + "SUNION", + "TTL", + "ZCARD", + "ZCOUNT", + "ZRANGE", + "ZSCORE", +]) + + +def cleanup_kwargs(**kwargs): + """ + Remove unsupported or disabled keys from kwargs + """ + connection_kwargs = { + k: v + for k, v in kwargs.items() + if k in REDIS_ALLOWED_KEYS and k not in KWARGS_DISABLED_KEYS + } + + return connection_kwargs + + +class ClusterParser(DefaultParser): + EXCEPTION_CLASSES = dict_merge( + DefaultParser.EXCEPTION_CLASSES, { + 'ASK': AskError, + 'TRYAGAIN': TryAgainError, + 'MOVED': MovedError, + 'CLUSTERDOWN': ClusterDownError, + 'CROSSSLOT': ClusterCrossSlotError, + 'MASTERDOWN': MasterDownError, + }) + + +class RedisCluster(ClusterCommands, object): + RedisClusterRequestTTL = 16 + + PRIMARIES = "all-primaries" + REPLICAS = "all-replicas" + ALL_NODES = "all-nodes" + RANDOM = "random" + + NODE_FLAGS = { + PRIMARIES, + REPLICAS, + ALL_NODES, + RANDOM + } + + COMMAND_FLAGS = dict_merge( + list_keys_to_dict( + [ + "CLIENT LIST", + "CLIENT SETNAME", + "CLIENT GETNAME", + "CONFIG GET", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "TIME", + "PUBSUB CHANNELS", + "PUBSUB NUMPAT", + "PUBSUB NUMSUB", + "PING", + "INFO", + "SHUTDOWN" + ], + ALL_NODES, + ), + list_keys_to_dict( + [ + "KEYS", + "SCAN", + "FLUSHALL", + "FLUSHDB", + "DBSIZE", + "BGSAVE", + "SLOWLOG GET", + "SLOWLOG LEN", + "SLOWLOG RESET", + "WAIT", + "TIME", + "SAVE", + "MEMORY PURGE", + "MEMORY MALLOC-STATS", + "MEMORY STATS", + "LASTSAVE", + "CLIENT TRACKINGINFO", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + "CLIENT UNBLOCK", + "CLIENT ID", + "CLIENT REPLY", + "CLIENT GETREDIR", + "CLIENT INFO", + "CLIENT KILL" + ], + PRIMARIES, + ), + list_keys_to_dict( + [ + "READONLY", + "READWRITE", + ], + REPLICAS, + ), + list_keys_to_dict( + [ + "CLUSTER INFO", + "CLUSTER NODES", + "CLUSTER REPLICAS", + "CLUSTER SLOTS", + "CLUSTER COUNT-FAILURE-REPORTS", + "CLUSTER KEYSLOT", + "RANDOMKEY", + "COMMAND", + "COMMAND GETKEYS", + "DEBUG", + ], + RANDOM, + ), + list_keys_to_dict( + [ + "CLUSTER COUNTKEYSINSLOT", + "CLUSTER DELSLOTS", + "CLUSTER GETKEYSINSLOT", + "CLUSTER SETSLOT", + ], + SLOT_ID, + ), + ) + + CLUSTER_COMMANDS_RESPONSE_CALLBACKS = { + 'CLUSTER ADDSLOTS': bool, + 'CLUSTER COUNT-FAILURE-REPORTS': int, + 'CLUSTER COUNTKEYSINSLOT': int, + 'CLUSTER DELSLOTS': bool, + 'CLUSTER FAILOVER': bool, + 'CLUSTER FORGET': bool, + 'CLUSTER GETKEYSINSLOT': list, + 'CLUSTER KEYSLOT': int, + 'CLUSTER MEET': bool, + 'CLUSTER REPLICATE': bool, + 'CLUSTER RESET': bool, + 'CLUSTER SAVECONFIG': bool, + 'CLUSTER SET-CONFIG-EPOCH': bool, + 'CLUSTER SETSLOT': bool, + 'CLUSTER SLOTS': parse_cluster_slots, + 'ASKING': bool, + 'READONLY': bool, + 'READWRITE': bool, + } + + RESULT_CALLBACKS = dict_merge( + list_keys_to_dict([ + "PUBSUB NUMSUB", + ], parse_pubsub_numsub), + list_keys_to_dict([ + "PUBSUB NUMPAT", + ], lambda 
command, res: sum(list(res.values()))), + list_keys_to_dict([ + "KEYS", + "PUBSUB CHANNELS", + ], merge_result), + list_keys_to_dict([ + "PING", + "CONFIG SET", + "CONFIG REWRITE", + "CONFIG RESETSTAT", + "CLIENT SETNAME", + "BGSAVE", + "SLOWLOG RESET", + "SAVE", + "MEMORY PURGE", + "CLIENT PAUSE", + "CLIENT UNPAUSE", + ], lambda command, res: all(res.values()) if isinstance(res, dict) + else res), + list_keys_to_dict([ + "DBSIZE", + "WAIT", + ], lambda command, res: sum(res.values()) if isinstance(res, dict) + else res), + list_keys_to_dict([ + "CLIENT UNBLOCK", + ], lambda command, res: 1 if sum(res.values()) > 0 else 0) + ) + + def __init__( + self, + host=None, + port=6379, + startup_nodes=None, + cluster_error_retry_attempts=3, + require_full_coverage=True, + skip_full_coverage_check=False, + reinitialize_steps=10, + read_from_replicas=False, + url=None, + debug=False, + **kwargs + ): + """ + :startup_nodes: 'list[ClusterNode]' + List of nodes from which initial bootstrapping can be done + :host: 'str' + Can be used to point to a startup node + :port: 'int' + Can be used to point to a startup node + :require_full_coverage: 'bool' + If set to True, as it is by default, all slots must be covered. + If set to False and not all slots are covered, the instance + creation will succeed only if 'cluster-require-full-coverage' + configuration is set to 'no' in all of the cluster's nodes. + Otherwise, RedisClusterException will be thrown. + :skip_full_coverage_check: 'bool' + If require_full_coverage is set to False, a check of + cluster-require-full-coverage config will be executed against all + nodes. Set skip_full_coverage_check to True to skip this check. + Useful for clusters without the CONFIG command (like ElastiCache) + :read_from_replicas: 'bool' + Enable read from replicas in READONLY mode. You can read possibly + stale data. + When set to true, read commands will be assigned between the + primary and its replications in a Round-Robin manner. 
+ :cluster_error_retry_attempts: 'int' + Retry command execution attempts when encountering ClusterDownError + or ConnectionError + :debug: + Add prints to debug the RedisCluster client + + :**kwargs: + Extra arguments that will be sent into Redis instance when created + (See Official redis-py doc for supported kwargs + [https://github.com/andymccurdy/redis-py/blob/master/redis/client.py]) + Some kwargs are not supported and will raise a + RedisClusterException: + - db (Redis do not support database SELECT in cluster mode) + """ + + if startup_nodes is None: + startup_nodes = [] + + if "db" in kwargs: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "Argument 'db' is not possible to use in cluster mode" + ) + + # Get the startup node/s + from_url = False + if url is not None: + from_url = True + url_options = parse_url(url) + if "path" in url_options: + raise RedisClusterException( + "RedisCluster does not currently support Unix Domain " + "Socket connections") + if "db" in url_options and url_options["db"] != 0: + # Argument 'db' is not possible to use in cluster mode + raise RedisClusterException( + "A ``db`` querystring option can only be 0 in cluster mode" + ) + kwargs.update(url_options) + startup_nodes.append(ClusterNode(kwargs['host'], kwargs['port'])) + elif host is not None and port is not None: + startup_nodes.append(ClusterNode(host, port)) + elif len(startup_nodes) == 0: + # No startup node was provided + raise RedisClusterException( + "RedisCluster requires at least one node to discover the " + "cluster. Please provide one of the followings:\n" + "1. host and port, for example:\n" + " RedisCluster(host='localhost', port=6379)\n" + "2. list of startup nodes, for example:\n" + " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," + " ClusterNode('localhost', 6378)])") + + # Update the connection arguments + # Whenever a new connection is established, RedisCluster's on_connect + # method should be run + # If the user passed on_connect function we'll save it and run it + # inside the RedisCluster.on_connect() function + self.user_on_connect_func = kwargs.pop("redis_connect_func", None) + kwargs.update({"redis_connect_func": self.on_connect}) + kwargs = cleanup_kwargs(**kwargs) + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + self.cluster_error_retry_attempts = cluster_error_retry_attempts + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.node_flags = self.__class__.NODE_FLAGS.copy() + self.debug_mode = debug + self.read_from_replicas = read_from_replicas + self.reinitialize_counter = 0 + self.reinitialize_steps = reinitialize_steps + self.nodes_manager = None + self.nodes_manager = NodesManager( + startup_nodes=startup_nodes, + from_url=from_url, + require_full_coverage=require_full_coverage, + skip_full_coverage_check=skip_full_coverage_check, + **kwargs, + ) + + self.cluster_response_callbacks = CaseInsensitiveDict( + self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS) + self.result_callbacks = CaseInsensitiveDict( + self.__class__.RESULT_CALLBACKS) + self.commands_parser = CommandsParser(self) + self._lock = threading.Lock() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.close() + + def __del__(self): + self.close() + + def disconnect_connection_pools(self): + for node in self.get_nodes(): + if node.redis_connection: + 
node.redis_connection.connection_pool.disconnect() + + @classmethod + def from_url(cls, url, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + return cls(url=url, **kwargs) + + def on_connect(self, connection): + """ + Initialize the connection, authenticate and select a database and send + READONLY if it is set during object initialization. + """ + connection.set_parser(ClusterParser) + connection.on_connect() + + if self.read_from_replicas: + # Sending READONLY command to server to configure connection as + # readonly. Since each cluster node may change its server type due + # to a failover, we should establish a READONLY connection + # regardless of the server type. If this is a primary connection, + # READONLY would not affect executing write commands. 
+ connection.send_command('READONLY') + if str_if_bytes(connection.read_response()) != 'OK': + raise ConnectionError('READONLY command failed') + + if self.user_on_connect_func is not None: + self.user_on_connect_func(connection) + + def get_redis_connection(self, node): + if not node.redis_connection: + with self._lock: + if not node.redis_connection: + self.nodes_manager.create_redis_connections([node]) + return node.redis_connection + + def get_node(self, host=None, port=None, node_name=None): + return self.nodes_manager.get_node(host, port, node_name) + + def get_primaries(self): + return self.nodes_manager.get_nodes_by_server_type(PRIMARY) + + def get_replicas(self): + return self.nodes_manager.get_nodes_by_server_type(REPLICA) + + def get_random_node(self): + return random.choice(list(self.nodes_manager.nodes_cache.values())) + + def get_nodes(self): + return list(self.nodes_manager.nodes_cache.values()) + + def pubsub(self, node=None, host=None, port=None, **kwargs): + """ + Allows passing a ClusterNode, or host&port, to get a pubsub instance + connected to the specified node + """ + return ClusterPubSub(self, node=node, host=host, port=port, **kwargs) + + def pipeline(self, transaction=None, + shard_hint=None, read_from_replicas=False): + """ + Cluster impl: + Pipelines do not work in cluster mode the same way they + do in normal mode. Create a clone of this object so + that simulating pipelines will work correctly. Each + command will be called directly when used and + when calling execute() will only return the result stack. + """ + if shard_hint: + raise RedisClusterException( + "shard_hint is deprecated in cluster mode") + + if transaction: + raise RedisClusterException( + "transaction is deprecated in cluster mode") + + return ClusterPipeline( + nodes_manager=self.nodes_manager, + startup_nodes=self.nodes_manager.startup_nodes, + result_callbacks=self.result_callbacks, + cluster_response_callbacks=self.cluster_response_callbacks, + cluster_error_retry_attempts=self.cluster_error_retry_attempts, + read_from_replicas=read_from_replicas, + ) + + def _determine_nodes(self, *args, **kwargs): + command = args[0] + nodes_flag = kwargs.pop("nodes_flag", None) + if nodes_flag is not None: + # nodes flag passed by the user + command_flag = nodes_flag + else: + # get the predefined nodes group for this command + command_flag = self.command_flags.get(command) + + if command_flag == self.__class__.RANDOM: + return [self.get_random_node()] + elif command_flag == self.__class__.PRIMARIES: + return self.get_primaries() + elif command_flag == self.__class__.REPLICAS: + return self.get_replicas() + elif command_flag == self.__class__.ALL_NODES: + return self.get_nodes() + else: + # get the node that holds the key's slot + slot = self.determine_slot(*args) + return [self.nodes_manager. + get_node_from_slot(slot, self.read_from_replicas + and command in READ_COMMANDS)] + + def _should_reinitialized(self): + # In order not to reinitialize the cluster, the user can set + # reinitialize_steps to 0. + if self.reinitialize_steps == 0: + return False + else: + return self.reinitialize_counter % self.reinitialize_steps == 0 + + def keyslot(self, key): + """ + Calculate keyslot for a given key. 
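+
+        Example (12182 is the CRC16-derived slot of 'foo'):
+
+            >>> rc.keyslot('foo')  # CRC16(key) % 16384, per the cluster spec
+            12182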
+ """ + k = self.encoder.encode(key) + return key_slot(k) + + def determine_slot(self, *args): + """ + figure out what slot based on command and args + """ + if self.command_flags.get(args[0]) == SLOT_ID: + # The command contains the slot ID + return args[1] + + redis_conn = self.get_random_node().redis_connection + keys = self.commands_parser.get_keys(redis_conn, *args) + if keys is None or len(keys) == 0: + raise RedisClusterException( + "No way to dispatch this command to Redis Cluster. " + "Missing key.\nYou can execute the command by specifying " + "target nodes.\nCommand: {0}".format(args) + ) + + if len(keys) > 1: + # multi-key command, we need to make sure all keys are mapped to + # the same slot + slots = {self.keyslot(key) for key in keys} + if len(slots) != 1: + raise RedisClusterException("{0} - all keys must map to the " + "same key slot".format(args[0])) + return slots.pop() + else: + # single key command + return self.keyslot(keys[0]) + + def reinitialize_caches(self): + self.nodes_manager.initialize() + + def _is_nodes_flag(self, target_nodes): + return isinstance(target_nodes, str) \ + and target_nodes in self.node_flags + + def _parse_target_nodes(self, target_nodes): + if isinstance(target_nodes, list): + nodes = target_nodes + elif isinstance(target_nodes, ClusterNode): + # Supports passing a single ClusterNode as a variable + nodes = [target_nodes] + elif isinstance(target_nodes, dict): + # Supports dictionaries of the format {node_name: node}. + # It enables to execute commands with multi nodes as follows: + # rc.cluster_save_config(rc.get_primaries()) + nodes = target_nodes.values() + else: + raise TypeError("target_nodes type can be one of the " + "followings: node_flag (PRIMARIES, " + "REPLICAS, RANDOM, ALL_NODES)," + "ClusterNode, list, or " + "dict. The passed type is {0}". + format(type(target_nodes))) + return nodes + + def execute_command(self, *args, **kwargs): + """ + Wrapper for ClusterDownError and ConnectionError error handling. + + It will try the number of times specified by the config option + "self.cluster_error_retry_attempts" which defaults to 3 unless manually + configured. + + If it reaches the number of times, the command will raise the exception + + Key argument :target_nodes: can be passed with the following types: + nodes_flag: PRIMARIES, REPLICAS, ALL_NODES, RANDOM + ClusterNode + list + dict + """ + target_nodes_specified = False + target_nodes = kwargs.pop("target_nodes", None) + if target_nodes is not None and not self._is_nodes_flag(target_nodes): + target_nodes = self._parse_target_nodes(target_nodes) + target_nodes_specified = True + # If ClusterDownError/ConnectionError were thrown, the nodes + # and slots cache were reinitialized. We will retry executing the + # command with the updated cluster setup only when the target nodes + # can be determined again with the new cache tables. Therefore, + # when target nodes were passed to this function, we cannot retry + # the command execution since the nodes may not be valid anymore + # after the tables were reinitialized. So in case of passed target + # nodes, retry_attempts will be set to 1. 
+ retry_attempts = 1 if target_nodes_specified else \ + self.cluster_error_retry_attempts + exception = None + for _ in range(0, retry_attempts): + try: + res = {} + if not target_nodes_specified: + # Determine the nodes to execute the command on + target_nodes = self._determine_nodes( + *args, **kwargs, nodes_flag=target_nodes) + if not target_nodes: + raise RedisClusterException( + "No targets were found to execute" + " {} command on".format(args)) + for node in target_nodes: + res[node.name] = self._execute_command( + node, *args, **kwargs) + # Return the processed result + return self._process_result(args[0], res, **kwargs) + except (ClusterDownError, ConnectionError) as e: + # The nodes and slots cache were reinitialized. + # Try again with the new cluster setup. All other errors + # should be raised. + exception = e + + # If it fails the configured number of times then raise exception back + # to caller of this method + raise exception + + def _execute_command(self, target_node, *args, **kwargs): + """ + Send a command to a node in the cluster + """ + command = args[0] + redis_node = None + connection = None + redirect_addr = None + asking = False + moved = False + ttl = int(self.RedisClusterRequestTTL) + connection_error_retry_counter = 0 + + while ttl > 0: + ttl -= 1 + try: + if asking: + target_node = self.get_node(node_name=redirect_addr) + elif moved: + # MOVED occurred and the slots cache was updated, + # refresh the target node + slot = self.determine_slot(*args) + target_node = self.nodes_manager. \ + get_node_from_slot(slot, self.read_from_replicas and + command in READ_COMMANDS) + moved = False + + if self.debug_mode: + print("Executing command {0} on target node: {1} {2}". + format(command, target_node.server_type, + target_node.name)) + redis_node = self.get_redis_connection(target_node) + connection = get_connection(redis_node, *args, **kwargs) + if asking: + connection.send_command("ASKING") + redis_node.parse_response(connection, "ASKING", **kwargs) + asking = False + + connection.send_command(*args) + response = redis_node.parse_response(connection, command, + **kwargs) + if command in self.cluster_response_callbacks: + response = self.cluster_response_callbacks[command]( + response, **kwargs) + return response + + except (RedisClusterException, BusyLoadingError): + warnings.warn("RedisClusterException || BusyLoadingError") + raise + except ConnectionError: + warnings.warn("ConnectionError") + # ConnectionError can also be raised if we couldn't get a + # connection from the pool before timing out, so check that + # this is an actual connection before attempting to disconnect. + if connection is not None: + connection.disconnect() + connection_error_retry_counter += 1 + + # Give the node 0.25 seconds to get back up and retry again + # with same node and configuration. After 5 attempts then try + # to reinitialize the cluster and see if the nodes + # configuration has changed or not + if connection_error_retry_counter < 5: + time.sleep(0.25) + else: + # Hard force of reinitialize of the node/slots setup + # and try again with the new setup + self.nodes_manager.initialize() + raise + except TimeoutError: + warnings.warn("TimeoutError") + if connection is not None: + connection.disconnect() + + if ttl < self.RedisClusterRequestTTL / 2: + time.sleep(0.05) + except MovedError as e: + # First, we will try to patch the slots/nodes cache with the + # redirected node output and try again. 
If MovedError exceeds + # 'reinitialize_steps' number of times, we will force + # reinitializing the tables, and then try again. + # 'reinitialize_steps' counter will increase faster when the + # same client object is shared between multiple threads. To + # reduce the frequency you can set this variable in the + # RedisCluster constructor. + warnings.warn("MovedError") + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() + else: + self.nodes_manager.update_moved_exception(e) + moved = True + except TryAgainError: + warnings.warn("TryAgainError") + + if ttl < self.RedisClusterRequestTTL / 2: + time.sleep(0.05) + except AskError as e: + warnings.warn("AskError") + + redirect_addr = get_node_name(host=e.host, port=e.port) + asking = True + except ClusterDownError as e: + warnings.warn("ClusterDownError") + # ClusterDownError can occur during a failover and to get + # self-healed, we will try to reinitialize the cluster layout + # and retry executing the command + time.sleep(0.05) + self.nodes_manager.initialize() + raise e + except ResponseError as e: + message = e.__str__() + warnings.warn("ResponseError: {0}".format(message)) + raise e + except BaseException as e: + warnings.warn("BaseException") + if connection: + connection.disconnect() + raise e + finally: + if connection is not None: + redis_node.connection_pool.release(connection) + + raise ClusterError("TTL exhausted.") + + def close(self): + try: + with self._lock: + if self.nodes_manager: + self.nodes_manager.close() + except AttributeError: + # RedisCluster's __init__ can fail before nodes_manager is set + pass + + def _process_result(self, command, res, **kwargs): + """ + Process the result of the executed command. + The function would return a dict or a single value. 
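+        For example, a command executed on two nodes might produce
+        {'127.0.0.1:6379': True, '127.0.0.1:6378': True} (hypothetical node
+        names).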
+ + :type command: str + :type res: dict + + `res` should be in the following format: + Dict + """ + if command in self.result_callbacks: + return self.result_callbacks[command](command, res, **kwargs) + elif len(res) == 1: + # When we execute the command on a single node, we can + # remove the dictionary and return a single response + return list(res.values())[0] + else: + return res + + +class ClusterNode(object): + def __init__(self, host, port, server_type=None, redis_connection=None): + if host == 'localhost': + host = socket.gethostbyname(host) + + self.host = host + self.port = port + self.name = get_node_name(host, port) + self.server_type = server_type + self.redis_connection = redis_connection + + def __repr__(self): + return '[host={0},port={1},' \ + 'name={2},server_type={3},redis_connection={4}]' \ + .format(self.host, + self.port, + self.name, + self.server_type, + self.redis_connection) + + def __eq__(self, obj): + return isinstance(obj, ClusterNode) and obj.name == self.name + + +class LoadBalancer: + """ + Round-Robin Load Balancing + """ + + def __init__(self, start_index=0): + self.primary_to_idx = {} + self.start_index = start_index + + def get_server_index(self, primary, list_size): + server_index = self.primary_to_idx.setdefault(primary, + self.start_index) + # Update the index + self.primary_to_idx[primary] = (server_index + 1) % list_size + return server_index + + def reset(self): + self.primary_to_idx.clear() + + +class NodesManager: + def __init__(self, startup_nodes, from_url=False, + require_full_coverage=True, skip_full_coverage_check=False, + lock=None, **kwargs): + self.nodes_cache = {} + self.slots_cache = {} + self.startup_nodes = {} + self.populate_startup_nodes(startup_nodes) + self.from_url = from_url + self._require_full_coverage = require_full_coverage + self._skip_full_coverage_check = skip_full_coverage_check + self._moved_exception = None + self.connection_kwargs = kwargs + self.read_load_balancer = LoadBalancer() + if lock is None: + lock = threading.Lock() + self._lock = lock + self.initialize() + + def get_node(self, host=None, port=None, node_name=None): + if node_name is None and (host is None or port is None): + warnings.warn( + "get_node requires one of the followings: " + "1. node name " + "2. host and port" + ) + return None + if host is not None and port is not None: + if host == "localhost": + host = socket.gethostbyname(host) + node_name = get_node_name(host=host, port=port) + return self.nodes_cache.get(node_name) + + def update_moved_exception(self, exception): + self._moved_exception = exception + + def _update_moved_slots(self): + e = self._moved_exception + redirected_node = self.get_node(host=e.host, port=e.port) + if redirected_node: + if redirected_node.server_type is not PRIMARY: + # Update the node's server type + redirected_node.server_type = PRIMARY + else: + # This is a new node, we will add it to the nodes cache + redirected_node = ClusterNode(e.host, e.port, PRIMARY) + self.nodes_cache[redirected_node.name] = redirected_node + if redirected_node in self.slots_cache[e.slot_id]: + # The MOVED error resulted from a failover, and the new slot owner + # had previously been a replica. 
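+            # e.g. the slot's primary failed over and one of its replicas
+            # was promoted, so the two nodes swap roles in our caches: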
+ old_primary = self.slots_cache[e.slot_id][0] + # Update the old primary to be a replica and add it to the end of + # the slot's node list + old_primary.server_type = REPLICA + self.slots_cache[e.slot_id].append(old_primary) + # Remove the old replica, which is now a primary, from the slot's + # node list + self.slots_cache[e.slot_id].remove(redirected_node) + # Override the old primary with the new one + self.slots_cache[e.slot_id][0] = redirected_node + else: + # The new slot owner is a new server, or a server from a different + # shard. We need to remove all current nodes from the slot's list + # (including replications) and add just the new node. + self.slots_cache[e.slot_id] = [redirected_node] + # Reset moved_exception + self._moved_exception = None + + def get_node_from_slot(self, slot, read_from_replicas=False, + server_type=None): + """ + Gets a node that servers this hash slot + """ + if self._moved_exception: + with self._lock: + if self._moved_exception: + self._update_moved_slots() + + if self.slots_cache.get(slot) is None or \ + len(self.slots_cache[slot]) == 0: + raise SlotNotCoveredError( + 'Slot "{0}" not covered by the cluster. ' + '"require_full_coverage={1}"'.format( + slot, self._require_full_coverage) + ) + + if read_from_replicas: + # get the server index in a Round-Robin manner + primary_name = self.slots_cache[slot][0].name + node_idx = self.read_load_balancer.get_server_index( + primary_name, len(self.slots_cache[slot])) + elif ( + server_type is None + or server_type == PRIMARY + or len(self.slots_cache[slot]) == 1 + ): + # return a primary + node_idx = 0 + else: + # return a replica + # randomly choose one of the replicas + node_idx = random.randint( + 1, len(self.slots_cache[slot]) - 1) + + return self.slots_cache[slot][node_idx] + + def get_nodes_by_server_type(self, server_type): + return [ + node + for node in self.nodes_cache.values() + if node.server_type == server_type + ] + + def populate_startup_nodes(self, nodes): + """ + Populate all startup nodes and filters out any duplicates + """ + for n in nodes: + self.startup_nodes[n.name] = n + + def cluster_require_full_coverage(self, cluster_nodes): + """ + if exists 'cluster-require-full-coverage no' config on redis servers, + then even all slots are not covered, cluster still will be able to + respond + """ + + def node_require_full_coverage(node): + try: + return ("yes" in node.redis_connection.config_get( + "cluster-require-full-coverage").values() + ) + except ConnectionError: + return False + except Exception as e: + raise RedisClusterException( + 'ERROR sending "config get cluster-require-full-coverage"' + ' command to redis server: {0}, {1}'.format(node.name, e) + ) + + # at least one node should have cluster-require-full-coverage yes + return any(node_require_full_coverage(node) + for node in cluster_nodes.values()) + + def check_slots_coverage(self, slots_cache): + # Validate if all slots are covered or if we should try next + # startup node + for i in range(0, REDIS_CLUSTER_HASH_SLOTS): + if i not in slots_cache: + return False + return True + + def create_redis_connections(self, nodes): + """ + This function will create a redis connection to all nodes in :nodes: + """ + for node in nodes: + if node.redis_connection is None: + node.redis_connection = self.create_redis_node( + host=node.host, + port=node.port, + **self.connection_kwargs, + ) + + def create_redis_node(self, host, port, **kwargs): + if self.from_url: + # Create a redis node with a costumed connection pool + kwargs.update({"host": 
host}) + kwargs.update({"port": port}) + connection_pool = ConnectionPool(**kwargs) + r = Redis( + connection_pool=connection_pool + ) + else: + r = Redis( + host=host, + port=port, + **kwargs + ) + return r + + def initialize(self): + """ + Initializes the nodes cache, slots cache and redis connections. + :startup_nodes: + Responsible for discovering other nodes in the cluster + """ + self.reset() + tmp_nodes_cache = {} + tmp_slots = {} + disagreements = [] + startup_nodes_reachable = False + kwargs = self.connection_kwargs + for startup_node in self.startup_nodes.values(): + try: + if startup_node.redis_connection: + r = startup_node.redis_connection + else: + # Create a new Redis connection and let Redis decode the + # responses so we won't need to handle that + copy_kwargs = copy.deepcopy(kwargs) + copy_kwargs.update({"decode_responses": True}) + copy_kwargs.update({"encoding": "utf-8"}) + r = self.create_redis_node( + startup_node.host, startup_node.port, **copy_kwargs) + self.startup_nodes[startup_node.name].redis_connection = r + cluster_slots = r.execute_command("CLUSTER SLOTS") + startup_nodes_reachable = True + except (ConnectionError, TimeoutError): + continue + except ResponseError as e: + warnings.warn( + 'ReseponseError sending "cluster slots" to redis server') + + # Isn't a cluster connection, so it won't parse these + # exceptions automatically + message = e.__str__() + if "CLUSTERDOWN" in message or "MASTERDOWN" in message: + continue + else: + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + 'server: {0}. error: {1}'.format( + startup_node, message) + ) + except Exception as e: + message = e.__str__() + raise RedisClusterException( + 'ERROR sending "cluster slots" command to redis ' + 'server: {0}. 
error: {1}'.format( + startup_node, message) + ) + + # If there's only one server in the cluster, its ``host`` is '' + # Fix it to the host in startup_nodes + if (len(cluster_slots) == 1 + and len(cluster_slots[0][2][0]) == 0 + and len(self.startup_nodes) == 1): + cluster_slots[0][2][0] = startup_node.host + + for slot in cluster_slots: + primary_node = slot[2] + host = primary_node[0] + if host == "": + host = startup_node.host + port = int(primary_node[1]) + + target_node = tmp_nodes_cache.get(get_node_name(host, port)) + if target_node is None: + target_node = ClusterNode(host, port, PRIMARY) + # add this node to the nodes cache + tmp_nodes_cache[target_node.name] = target_node + + for i in range(int(slot[0]), int(slot[1]) + 1): + if i not in tmp_slots: + tmp_slots[i] = [] + tmp_slots[i].append(target_node) + replica_nodes = [slot[j] for j in range(3, len(slot))] + + for replica_node in replica_nodes: + host = replica_node[0] + port = replica_node[1] + + target_replica_node = tmp_nodes_cache.get( + get_node_name(host, port)) + if target_replica_node is None: + target_replica_node = ClusterNode( + host, port, REPLICA) + tmp_slots[i].append(target_replica_node) + # add this node to the nodes cache + tmp_nodes_cache[ + target_replica_node.name + ] = target_replica_node + else: + # Validate that 2 nodes want to use the same slot cache + # setup + if tmp_slots[i][0].name != target_node.name: + disagreements.append( + '{0} vs {1} on slot: {2}'.format( + tmp_slots[i][0].name, target_node.name, i) + ) + + if len(disagreements) > 5: + raise RedisClusterException( + 'startup_nodes could not agree on a valid' + ' slots cache: {0}'.format( + ", ".join(disagreements)) + ) + + if not startup_nodes_reachable: + raise RedisClusterException( + "Redis Cluster cannot be connected. Please provide at least " + "one reachable node. " + ) + + # Create Redis connections to all nodes + self.create_redis_connections(list(tmp_nodes_cache.values())) + + fully_covered = self.check_slots_coverage(tmp_slots) + if not fully_covered: + if self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + 'All slots are not covered after query all startup_nodes.' + ' {0} of {1} covered...'.format( + len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) + ) + else: + # The user set require_full_coverage to False. + # In case of full coverage requirement in the cluster's Redis + # configurations, we will raise an exception. Otherwise, we may + # continue with partial coverage. + # see Redis Cluster configuration parameters in + # https://redis.io/topics/cluster-tutorial + if not self._skip_full_coverage_check and \ + self.cluster_require_full_coverage(tmp_nodes_cache): + raise RedisClusterException( + 'Not all slots are covered but the cluster\'s ' + 'configuration requires full coverage. Set ' + 'cluster-require-full-coverage configuration to no on ' + 'all of the cluster nodes if you wish the cluster to ' + 'be able to serve without being fully covered.' 
+ ' {0} of {1} covered...'.format( + len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) + ) + + # Set the tmp variables to the real variables + self.nodes_cache = tmp_nodes_cache + self.slots_cache = tmp_slots + # Populate the startup nodes with all discovered nodes + self.populate_startup_nodes(self.nodes_cache.values()) + + def close(self): + for node in self.nodes_cache.values(): + if node.redis_connection: + node.redis_connection.close() + + def reset(self): + if self.read_load_balancer is not None: + self.read_load_balancer.reset() + + +class ClusterPubSub(PubSub): + """ + Wrapper for PubSub class. + + IMPORTANT: before using ClusterPubSub, read about the known limitations + with pubsub in Cluster mode and learn how to workaround them: + https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + """ + + def __init__(self, redis_cluster, node=None, host=None, port=None, + **kwargs): + """ + When a pubsub instance is created without specifying a node, a single + node will be transparently chosen for the pubsub connection on the + first command execution. The node will be determined by: + 1. Hashing the channel name in the request to find its keyslot + 2. Selecting a node that handles the keyslot: If read_from_replicas is + set to true, a replica can be selected. + """ + self.node = None + connection_pool = None + if host is not None and port is not None: + node = redis_cluster.get_node(host=host, port=port) + self.node = node + if node is not None: + if not isinstance(node, ClusterNode): + raise DataError("'node' must be a ClusterNode") + connection_pool = redis_cluster.get_redis_connection(node). \ + connection_pool + self.cluster = redis_cluster + super().__init__(**kwargs, connection_pool=connection_pool, + encoder=redis_cluster.encoder) + + def execute_command(self, *args, **kwargs): + """ + Execute a publish/subscribe command. + + Taken code from redis-py and tweak to make it work within a cluster. + """ + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + if self.connection_pool is None: + if len(args) > 1: + # Hash the first channel and get one of the nodes holding + # this slot + channel = args[1] + slot = self.cluster.keyslot(channel) + node = self.cluster.nodes_manager. \ + get_node_from_slot(slot, self.cluster. + read_from_replicas) + else: + # Get a random node + node = self.cluster.get_random_node() + self.node = node + redis_connection = self.cluster.get_redis_connection(node) + self.connection_pool = redis_connection.connection_pool + self.connection = self.connection_pool.get_connection( + 'pubsub', + self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + self._execute(connection, connection.send_command, *args) + + def get_redis_connection(self): + """ + Get the Redis connection of the pubsub connected node. 
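+        Returns None if the instance has not yet been bound to a node.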
+ """ + if self.node is not None: + return self.node.redis_connection + + +ERRORS_ALLOW_RETRY = (ConnectionError, TimeoutError, + MovedError, AskError, TryAgainError) + + +class ClusterPipeline(RedisCluster): + """ + Support for Redis pipeline + in cluster mode + """ + + def __init__(self, nodes_manager, result_callbacks=None, + cluster_response_callbacks=None, startup_nodes=None, + read_from_replicas=False, cluster_error_retry_attempts=3, + debug=False, **kwargs): + """ + """ + self.command_stack = [] + self.debug_mode = debug + self.nodes_manager = nodes_manager + self.refresh_table_asap = False + self.result_callbacks = (result_callbacks or + self.__class__.RESULT_CALLBACKS.copy()) + self.startup_nodes = startup_nodes if startup_nodes else [] + self.read_from_replicas = read_from_replicas + self.command_flags = self.__class__.COMMAND_FLAGS.copy() + self.cluster_response_callbacks = cluster_response_callbacks + self.cluster_error_retry_attempts = cluster_error_retry_attempts + + self.encoder = Encoder( + kwargs.get("encoding", "utf-8"), + kwargs.get("encoding_errors", "strict"), + kwargs.get("decode_responses", False), + ) + + # The commands parser refers to the parent + # so that we don't push the COMMAND command + # onto the stack + self.commands_parser = CommandsParser(super()) + + def __repr__(self): + """ + """ + return "{0}".format(type(self).__name__) + + def __enter__(self): + """ + """ + return self + + def __exit__(self, exc_type, exc_value, traceback): + """ + """ + self.reset() + + def __del__(self): + try: + self.reset() + except Exception: + pass + + def __len__(self): + """ + """ + return len(self.command_stack) + + def __nonzero__(self): + "Pipeline instances should always evaluate to True on Python 2.7" + return True + + def __bool__(self): + "Pipeline instances should always evaluate to True on Python 3+" + return True + + def execute_command(self, *args, **kwargs): + """ + """ + return self.pipeline_execute_command(*args, **kwargs) + + def pipeline_execute_command(self, *args, **options): + """ + """ + self.command_stack.append( + PipelineCommand(args, options, len(self.command_stack))) + return self + + def raise_first_error(self, stack): + """ + """ + for c in stack: + r = c.result + if isinstance(r, Exception): + self.annotate_exception(r, c.position + 1, c.args) + raise r + + def annotate_exception(self, exception, number, command): + """ + """ + cmd = ' '.join(map(safe_str, command)) + msg = 'Command # %d (%s) of pipeline caused error: %s' % ( + number, cmd, exception.args[0]) + exception.args = (msg,) + exception.args[1:] + + def execute(self, raise_on_error=True): + """ + """ + stack = self.command_stack + + if not stack: + return [] + + try: + return self.send_cluster_commands(stack, raise_on_error) + finally: + self.reset() + + def reset(self): + """ + Reset back to empty pipeline. 
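
For orientation, a minimal pipeline usage sketch (assuming a local cluster and a `pipeline()` factory on the cluster client, mirroring the standalone API): commands are queued client-side and only dispatched, per node, on execute():

``` pycon
    >>> pipe = rc.pipeline()
    >>> pipe.set('foo', 'bar')
    >>> pipe.get('foo')
    >>> pipe.execute()
    [True, b'bar']
```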
+        """
+        self.command_stack = []
+
+        self.scripts = set()
+
+        # TODO: Implement
+        # make sure to reset the connection state in the event that we were
+        # watching something
+        # if self.watching and self.connection:
+        #     try:
+        #         # call this manually since our unwatch or
+        #         # immediate_execute_command methods can call reset()
+        #         self.connection.send_command('UNWATCH')
+        #         self.connection.read_response()
+        #     except ConnectionError:
+        #         # disconnect will also remove any previous WATCHes
+        #         self.connection.disconnect()
+
+        # clean up the other instance attributes
+        self.watching = False
+        self.explicit_transaction = False
+
+        # TODO: Implement
+        # we can safely return the connection to the pool here since we're
+        # sure we're no longer WATCHing anything
+        # if self.connection:
+        #     self.connection_pool.release(self.connection)
+        #     self.connection = None
+
+    def send_cluster_commands(self, stack,
+                              raise_on_error=True, allow_redirections=True):
+        """
+        Wrapper for CLUSTERDOWN error handling.
+
+        If the cluster reports it is down, it is assumed that:
+        - connection_pool was disconnected
+        - connection_pool was reset
+        - refresh_table_asap was set to True
+
+        It will retry the number of times specified by
+        the config option "self.cluster_error_retry_attempts",
+        which defaults to 3 unless manually configured.
+
+        If all attempts fail, the command raises a ClusterDownError.
+        """
+        for _ in range(0, self.cluster_error_retry_attempts):
+            try:
+                return self._send_cluster_commands(
+                    stack,
+                    raise_on_error=raise_on_error,
+                    allow_redirections=allow_redirections,
+                )
+            except ClusterDownError:
+                # Try again with the new cluster setup. All other errors
+                # should be raised.
+                pass
+
+        # If it fails the configured number of times, raise the
+        # exception back to the caller of this method
+        raise ClusterDownError(
+            "CLUSTERDOWN error. Unable to rebuild the cluster")
+
+    def _send_cluster_commands(self, stack,
+                               raise_on_error=True,
+                               allow_redirections=True):
+        """
+        Send a bunch of cluster commands to the redis cluster.
+
+        `allow_redirections` If the pipeline should follow
+        `ASK` & `MOVED` responses automatically. If set
+        to false it will raise RedisClusterException.
+        """
+        # the first time sending the commands we send all of
+        # the commands that were queued up.
+        # if we have to run through it again, we only retry
+        # the commands that failed.
+        attempt = sorted(stack, key=lambda x: x.position)
+
+        # build a list of node objects based on the node names we need to
+        # send commands to
+        nodes = {}
+
+        # as we move through each command that still needs to be processed,
+        # we figure out the slot number that command maps to, then from
+        # the slot determine the node.
+        for c in attempt:
+            # refer to our internal node -> slot table that tells us
+            # where a given command should route to.
+            slot = self.determine_slot(*c.args)
+            node = self.nodes_manager.get_node_from_slot(
+                slot, self.read_from_replicas and c.args[0] in READ_COMMANDS)
+
+            # now that we know the name of the node
+            # ( it's just a string in the form of host:port )
+            # we can build a list of commands for each node.
+            node_name = node.name
+            if node_name not in nodes:
+                redis_node = self.get_redis_connection(node)
+                connection = get_connection(redis_node, c.args)
+                nodes[node_name] = NodeCommands(redis_node.parse_response,
+                                                redis_node.connection_pool,
+                                                connection)
+
+            nodes[node_name].append(c)
+
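The retry budget used by the loop above comes from the constructor; a hedged sketch (assuming RedisCluster accepts the same `cluster_error_retry_attempts` keyword shown in the pipeline constructor above):

``` pycon
    >>> rc = Redis(host='localhost', port=6379,
    ...            cluster_error_retry_attempts=5)
```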
+        # send the commands in sequence.
+        # we write to all the open sockets for each node first,
+        # before reading anything
+        # this allows us to flush all the requests out across the
+        # network essentially in parallel
+        # so that we can read them all in parallel as they come back.
+        # we don't multiplex on the sockets as they come available,
+        # but that shouldn't make too much difference.
+        node_commands = nodes.values()
+        for n in node_commands:
+            n.write()
+
+        for n in node_commands:
+            n.read()
+
+        # release all of the redis connections we allocated earlier
+        # back into the connection pool.
+        # we used to do this step as part of a try/finally block,
+        # but it is really dangerous to
+        # release connections back into the pool if for some
+        # reason the socket has data still left in it
+        # from a previous operation. The write and
+        # read operations already have try/catch around them for
+        # all known types of errors including connection
+        # and socket level errors.
+        # So if we hit an exception, something really bad
+        # happened and putting any of
+        # these connections back into the pool is a very bad idea.
+        # the socket might have unread buffer still sitting in it,
+        # and then the next time we read from it we pass the
+        # buffered result back from a previous command and
+        # every subsequent request to that connection will get
+        # a mismatched result.
+        for n in nodes.values():
+            n.connection_pool.release(n.connection)
+
+        # if the response isn't an exception, it is a
+        # valid response from the node
+        # we're all done with that command, YAY!
+        # if we have more commands to attempt, we've run into problems.
+        # collect all the commands we are allowed to retry.
+        # (MOVED, ASK, or connection errors or timeout errors)
+        attempt = sorted([c for c in attempt
+                          if isinstance(c.result, ERRORS_ALLOW_RETRY)],
+                         key=lambda x: x.position)
+        if attempt and allow_redirections:
+            # RETRY MAGIC HAPPENS HERE!
+            # send these remaining commands one at a time using
+            # `execute_command` in the main client. This keeps our
+            # retry logic in one place mostly,
+            # and allows us to be more confident in correctness of behavior.
+            # at this point any speed gains from pipelining have been lost
+            # anyway, so we might as well make the best
+            # attempt to get the correct behavior.
+            #
+            # The client command will handle retries for each
+            # individual command sequentially as we pass each
+            # one into `execute_command`. Any exceptions
+            # that bubble out should only appear once all
+            # retries have been exhausted.
+            #
+            # If a lot of commands have failed, we'll be setting the
+            # flag to rebuild the slots table from scratch.
+            # So MOVED errors should correct themselves fairly quickly.
+            self.nodes_manager. \
+                increment_reinitialize_counter(len(attempt))
+            for c in attempt:
+                try:
+                    # send each command individually like we
+                    # do in the main client.
+                    c.result = super(ClusterPipeline, self).
\ + execute_command(*c.args, **c.options) + except RedisError as e: + c.result = e + + # turn the response back into a simple flat array that corresponds + # to the sequence of commands issued in the stack in pipeline.execute() + response = [c.result for c in sorted(stack, key=lambda x: x.position)] + + if raise_on_error: + self.raise_first_error(stack) + + return response + + def _fail_on_redirect(self, allow_redirections): + """ + """ + if not allow_redirections: + raise RedisClusterException( + "ASK & MOVED redirection not allowed in this pipeline") + + def multi(self): + """ + """ + raise RedisClusterException("method multi() is not implemented") + + def immediate_execute_command(self, *args, **options): + """ + """ + raise RedisClusterException( + "method immediate_execute_command() is not implemented") + + def _execute_transaction(self, *args, **kwargs): + """ + """ + raise RedisClusterException( + "method _execute_transaction() is not implemented") + + def load_scripts(self): + """ + """ + raise RedisClusterException( + "method load_scripts() is not implemented") + + def watch(self, *names): + """ + """ + raise RedisClusterException("method watch() is not implemented") + + def unwatch(self): + """ + """ + raise RedisClusterException("method unwatch() is not implemented") + + def script_load_for_pipeline(self, *args, **kwargs): + """ + """ + raise RedisClusterException( + "method script_load_for_pipeline() is not implemented") + + def delete(self, *names): + """ + "Delete a key specified by ``names``" + """ + if len(names) != 1: + raise RedisClusterException( + "deleting multiple keys is not " + "implemented in pipeline command") + + return self.execute_command('DEL', names[0]) + + +def block_pipeline_command(func): + """ + Prints error because some pipelined commands should + be blocked when running in cluster-mode + """ + + def inner(*args, **kwargs): + raise RedisClusterException( + "ERROR: Calling pipelined function {0} is blocked when " + "running redis in cluster mode...".format(func.__name__)) + + return inner + + +# Blocked pipeline commands +ClusterPipeline.bitop = block_pipeline_command(RedisCluster.bitop) +ClusterPipeline.brpoplpush = block_pipeline_command(RedisCluster.brpoplpush) +ClusterPipeline.client_getname = \ + block_pipeline_command(RedisCluster.client_getname) +ClusterPipeline.client_list = block_pipeline_command(RedisCluster.client_list) +ClusterPipeline.client_setname = \ + block_pipeline_command(RedisCluster.client_setname) +ClusterPipeline.config_set = block_pipeline_command(RedisCluster.config_set) +ClusterPipeline.dbsize = block_pipeline_command(RedisCluster.dbsize) +ClusterPipeline.flushall = block_pipeline_command(RedisCluster.flushall) +ClusterPipeline.flushdb = block_pipeline_command(RedisCluster.flushdb) +ClusterPipeline.keys = block_pipeline_command(RedisCluster.keys) +ClusterPipeline.mget = block_pipeline_command(RedisCluster.mget) +ClusterPipeline.move = block_pipeline_command(RedisCluster.move) +ClusterPipeline.mset = block_pipeline_command(RedisCluster.mset) +ClusterPipeline.msetnx = block_pipeline_command(RedisCluster.msetnx) +ClusterPipeline.pfmerge = block_pipeline_command(RedisCluster.pfmerge) +ClusterPipeline.pfcount = block_pipeline_command(RedisCluster.pfcount) +ClusterPipeline.ping = block_pipeline_command(RedisCluster.ping) +ClusterPipeline.publish = block_pipeline_command(RedisCluster.publish) +ClusterPipeline.randomkey = block_pipeline_command(RedisCluster.randomkey) +ClusterPipeline.rename = block_pipeline_command(RedisCluster.rename) 
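
The list of blocked commands continues below; calling any of them on a pipeline fails immediately rather than being queued. A sketch of the expected behavior (assuming `rc` is a connected cluster client with a `pipeline()` factory):

``` pycon
    >>> rc.pipeline().keys()
    Traceback (most recent call last):
        ...
    redis.exceptions.RedisClusterException: ERROR: Calling pipelined function keys is blocked when running redis in cluster mode...
```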
+ClusterPipeline.renamenx = block_pipeline_command(RedisCluster.renamenx) +ClusterPipeline.rpoplpush = block_pipeline_command(RedisCluster.rpoplpush) +ClusterPipeline.scan = block_pipeline_command(RedisCluster.scan) +ClusterPipeline.sdiff = block_pipeline_command(RedisCluster.sdiff) +ClusterPipeline.sdiffstore = block_pipeline_command(RedisCluster.sdiffstore) +ClusterPipeline.sinter = block_pipeline_command(RedisCluster.sinter) +ClusterPipeline.sinterstore = block_pipeline_command(RedisCluster.sinterstore) +ClusterPipeline.smove = block_pipeline_command(RedisCluster.smove) +ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) +ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) +ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) + + +class PipelineCommand(object): + """ + """ + + def __init__(self, args, options=None, position=None): + self.args = args + if options is None: + options = {} + self.options = options + self.position = position + self.result = None + self.node = None + self.asking = False + + +class NodeCommands(object): + """ + """ + + def __init__(self, parse_response, connection_pool, connection): + """ + """ + self.parse_response = parse_response + self.connection_pool = connection_pool + self.connection = connection + self.commands = [] + + def append(self, c): + """ + """ + self.commands.append(c) + + def write(self): + """ + Code borrowed from Redis so it can be fixed + """ + connection = self.connection + commands = self.commands + + # We are going to clobber the commands with the write, so go ahead + # and ensure that nothing is sitting there from a previous run. + for c in commands: + c.result = None + + # build up all commands into a single request to increase network perf + # send all the commands and catch connection and timeout errors. + try: + connection.send_packed_command( + connection.pack_commands([c.args for c in commands])) + except (ConnectionError, TimeoutError) as e: + for c in commands: + c.result = e + + def read(self): + """ + """ + connection = self.connection + for c in self.commands: + + # if there is a result on this command, + # it means we ran into an exception + # like a connection error. Trying to parse + # a response on a connection that + # is no longer open will result in a + # connection error raised by redis-py. + # but redis-py doesn't check in parse_response + # that the sock object is + # still set and if you try to + # read from a closed connection, it will + # result in an AttributeError because + # it will do a readline() call on None. + # This can have all kinds of nasty side-effects. + # Treating this case as a connection error + # is fine because it will dump + # the connection object back into the + # pool and on the next write, it will + # explicitly open the connection and all will be well. 
+ if c.result is None: + try: + c.result = self.parse_response( + connection, c.args[0], **c.options) + except (ConnectionError, TimeoutError) as e: + for c in self.commands: + c.result = e + return + except RedisError: + c.result = sys.exc_info()[1] diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py index f1ddaaabc1..60f13d8d35 100644 --- a/redis/commands/__init__.py +++ b/redis/commands/__init__.py @@ -2,9 +2,13 @@ from .redismodules import RedisModuleCommands from .helpers import list_or_args from .sentinel import SentinelCommands +from .cluster import ClusterCommands +from .parser import CommandsParser __all__ = [ 'CoreCommands', + 'ClusterCommands', + 'CommandsParser', 'RedisModuleCommands', 'SentinelCommands', 'list_or_args' diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py new file mode 100644 index 0000000000..d3c2e8572b --- /dev/null +++ b/redis/commands/cluster.py @@ -0,0 +1,801 @@ +from redis.exceptions import ( + ConnectionError, + DataError, + RedisError, +) +from redis.crc import key_slot +from .core import DataAccessCommands, PubSubCommands +from .helpers import list_or_args + + +class ClusterMultiKeyCommands: + """ + A class containing commands that handle more than one key + """ + + def _partition_keys_by_slot(self, keys): + """ + Split keys into a dictionary that maps a slot to + a list of keys. + """ + slots_to_keys = {} + for key in keys: + k = self.encoder.encode(key) + slot = key_slot(k) + slots_to_keys.setdefault(slot, []).append(key) + + return slots_to_keys + + def mget_nonatomic(self, keys, *args): + """ + Splits the keys into different slots and then calls MGET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + + Returns a list of values ordered identically to ``keys`` + """ + + from redis.client import EMPTY_RESPONSE + options = {} + if not args: + options[EMPTY_RESPONSE] = [] + + # Concatenate all keys into a list + keys = list_or_args(keys, args) + # Split keys into slots + slots_to_keys = self._partition_keys_by_slot(keys) + + # Call MGET for every slot and concatenate + # the results + # We must make sure that the keys are returned in order + all_results = {} + for slot_keys in slots_to_keys.values(): + slot_values = self.execute_command( + 'MGET', *slot_keys, **options) + + slot_results = dict(zip(slot_keys, slot_values)) + all_results.update(slot_results) + + # Sort the results + vals_in_order = [all_results[key] for key in keys] + return vals_in_order + + def mset_nonatomic(self, mapping): + """ + Sets key/values based on a mapping. Mapping is a dictionary of + key/value pairs. Both keys and values should be strings or types that + can be cast to a string via str(). + + Splits the keys into different slots and then calls MSET + for the keys of every slot. This operation will not be atomic + if keys belong to more than one slot. + """ + + # Partition the keys by slot + slots_to_pairs = {} + for pair in mapping.items(): + # encode the key + k = self.encoder.encode(pair[0]) + slot = key_slot(k) + slots_to_pairs.setdefault(slot, []).extend(pair) + + # Call MSET for every slot and concatenate + # the results (one result per slot) + res = [] + for pairs in slots_to_pairs.values(): + res.append(self.execute_command('MSET', *pairs)) + + return res + + def _split_command_across_slots(self, command, *keys): + """ + Runs the given command once for the keys + of each slot. Returns the sum of the return values. 
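
A usage sketch for the non-atomic multi-key helpers defined above (the slot values assume the standard CRC16 mapping, e.g. 'foo' -> 12182 and 'bar' -> 5061, so the two keys live in different slots):

``` pycon
    >>> rc.mset_nonatomic({'foo': '1', 'bar': '2'})  # one MSET per slot
    [True, True]
    >>> rc.mget_nonatomic('foo', 'bar')              # values in input order
    [b'1', b'2']
    >>> rc.delete('foo', 'bar')                      # one DEL per slot, summed
    2
```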
+        """
+        # Partition the keys by slot
+        slots_to_keys = self._partition_keys_by_slot(keys)
+
+        # Sum up the reply from each command
+        total = 0
+        for slot_keys in slots_to_keys.values():
+            total += self.execute_command(command, *slot_keys)
+
+        return total
+
+    def exists(self, *keys):
+        """
+        Returns the number of ``keys`` that exist in the
+        whole cluster. The keys are first split up into slots
+        and then an EXISTS command is sent for every slot
+        """
+        return self._split_command_across_slots('EXISTS', *keys)
+
+    def delete(self, *keys):
+        """
+        Deletes the given keys in the cluster.
+        The keys are first split up into slots
+        and then a DEL command is sent for every slot
+
+        Non-existent keys are ignored.
+        Returns the number of keys that were deleted.
+        """
+        return self._split_command_across_slots('DEL', *keys)
+
+    def touch(self, *keys):
+        """
+        Updates the last access time of the given keys across the
+        cluster.
+
+        The keys are first split up into slots
+        and then a TOUCH command is sent for every slot
+
+        Non-existent keys are ignored.
+        Returns the number of keys that were touched.
+        """
+        return self._split_command_across_slots('TOUCH', *keys)
+
+    def unlink(self, *keys):
+        """
+        Remove the specified keys in a different thread.
+
+        The keys are first split up into slots
+        and then an UNLINK command is sent for every slot
+
+        Non-existent keys are ignored.
+        Returns the number of keys that were unlinked.
+        """
+        return self._split_command_across_slots('UNLINK', *keys)
+
+
+class ClusterManagementCommands:
+    def bgsave(self, schedule=True, target_nodes=None):
+        """
+        Tell the Redis server to save its data to disk. Unlike save(),
+        this method is asynchronous and returns immediately.
+        """
+        pieces = []
+        if schedule:
+            pieces.append("SCHEDULE")
+        return self.execute_command('BGSAVE',
+                                    *pieces,
+                                    target_nodes=target_nodes)
+
+    def client_getname(self, target_nodes=None):
+        """
+        Returns the current connection name from all nodes.
+        The result will be a dictionary with the IP and
+        connection name.
+        """
+        return self.execute_command('CLIENT GETNAME',
+                                    target_nodes=target_nodes)
+
+    def client_getredir(self, target_nodes=None):
+        """
+        Returns the ID (an integer) of the client to whom we are
+        redirecting tracking notifications.
+
+        see: https://redis.io/commands/client-getredir
+        """
+        return self.execute_command('CLIENT GETREDIR',
+                                    target_nodes=target_nodes)
+
+    def client_id(self, target_nodes=None):
+        """Returns the current connection id"""
+        return self.execute_command('CLIENT ID',
+                                    target_nodes=target_nodes)
+
+    def client_info(self, target_nodes=None):
+        """
+        Returns information and statistics about the current
+        client connection.
+        """
+        return self.execute_command('CLIENT INFO',
+                                    target_nodes=target_nodes)
+
+    def client_kill_filter(self, _id=None, _type=None, addr=None,
+                           skipme=None, laddr=None, user=None,
+                           target_nodes=None):
+        """
+        Disconnects client(s) using a variety of filter options
+        :param _id: Kills a client by its unique ID field
+        :param _type: Kills a client by type, where type is one of 'normal',
+            'master', 'slave' or 'pubsub'
+        :param addr: Kills a client by its 'address:port'
+        :param skipme: If True, then the client calling the command
+            will not get killed even if it is identified by one of the filter
+            options. If skipme is not provided, the server defaults to
+            skipme=True
+        :param laddr: Kills a client by its 'local (bind) address:port'
+        :param user: Kills a client for a specific user name
+        """
+        args = []
+        if _type is not None:
+            client_types = ('normal', 'master', 'slave', 'pubsub')
+            if str(_type).lower() not in client_types:
+                raise DataError("CLIENT KILL type must be one of %r" % (
+                    client_types,))
+            args.extend((b'TYPE', _type))
+        if skipme is not None:
+            if not isinstance(skipme, bool):
+                raise DataError("CLIENT KILL skipme must be a bool")
+            if skipme:
+                args.extend((b'SKIPME', b'YES'))
+            else:
+                args.extend((b'SKIPME', b'NO'))
+        if _id is not None:
+            args.extend((b'ID', _id))
+        if addr is not None:
+            args.extend((b'ADDR', addr))
+        if laddr is not None:
+            args.extend((b'LADDR', laddr))
+        if user is not None:
+            args.extend((b'USER', user))
+        if not args:
+            raise DataError("CLIENT KILL <filter> <value> ... ... <filter> "
+                            "<value> must specify at least one filter")
+        return self.execute_command('CLIENT KILL', *args,
+                                    target_nodes=target_nodes)
+
+    def client_kill(self, address, target_nodes=None):
+        "Disconnects the client at ``address`` (ip:port)"
+        return self.execute_command('CLIENT KILL', address,
+                                    target_nodes=target_nodes)
+
+    def client_list(self, _type=None, target_nodes=None):
+        """
+        Returns a list of the clients currently connected to the entire
+        cluster. If a client type is specified, only clients of that type
+        will be returned.
+        :param _type: optional. one of the client types (normal, master,
+            replica, pubsub)
+        """
+        if _type is not None:
+            client_types = ('normal', 'master', 'replica', 'pubsub')
+            if str(_type).lower() not in client_types:
+                raise DataError("CLIENT LIST _type must be one of %r" % (
+                    client_types,))
+            return self.execute_command('CLIENT LIST',
+                                        b'TYPE',
+                                        _type,
+                                        target_nodes=target_nodes)
+        return self.execute_command('CLIENT LIST',
+                                    target_nodes=target_nodes)
+
+    def client_pause(self, timeout, target_nodes=None):
+        """
+        Suspend all the Redis clients for the specified amount of time
+        :param timeout: milliseconds to pause clients
+        """
+        if not isinstance(timeout, int):
+            raise DataError("CLIENT PAUSE timeout must be an integer")
+        return self.execute_command('CLIENT PAUSE', str(timeout),
+                                    target_nodes=target_nodes)
+
+    def client_reply(self, reply, target_nodes=None):
+        """
+        Enable and disable redis server replies.
+        ``reply`` must be ON, OFF or SKIP:
+        ON - The default mode, in which the server replies to commands
+        OFF - Disable server responses to commands
+        SKIP - Skip the response of the immediately following command
+
+        Note: When setting OFF or SKIP replies, you will need a client object
+        with a timeout specified in seconds, and will need to catch the
+        TimeoutError.
+        The test_client_reply unit test illustrates this, and
+        conftest.py has a client with a timeout.
+        See https://redis.io/commands/client-reply
+        """
+        replies = ['ON', 'OFF', 'SKIP']
+        if reply not in replies:
+            raise DataError('CLIENT REPLY must be one of %r' % replies)
+        return self.execute_command("CLIENT REPLY", reply,
+                                    target_nodes=target_nodes)
+
+    def client_setname(self, name, target_nodes=None):
+        "Sets the current connection name"
+        return self.execute_command('CLIENT SETNAME', name,
+                                    target_nodes=target_nodes)
+
+    def client_trackinginfo(self, target_nodes=None):
+        """
+        Returns information about the current client connection's
+        use of the server assisted client side cache.
+ See https://redis.io/commands/client-trackinginfo + """ + return self.execute_command('CLIENT TRACKINGINFO', + target_nodes=target_nodes) + + def client_unblock(self, client_id, error=False, target_nodes=None): + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism. + """ + args = ['CLIENT UNBLOCK', int(client_id)] + if error: + args.append(b'ERROR') + return self.execute_command(*args, target_nodes=target_nodes) + + def client_unpause(self, target_nodes=None): + """ + Unpause all redis clients + """ + return self.execute_command('CLIENT UNPAUSE', + target_nodes=target_nodes) + + def config_get(self, pattern="*", target_nodes=None): + """Return a dictionary of configuration based on the ``pattern``""" + return self.execute_command('CONFIG GET', + pattern, + target_nodes=target_nodes) + + def config_resetstat(self, target_nodes=None): + """Reset runtime statistics""" + return self.execute_command('CONFIG RESETSTAT', + target_nodes=target_nodes) + + def config_rewrite(self, target_nodes=None): + """ + Rewrite config file with the minimal change to reflect running config. + """ + return self.execute_command('CONFIG REWRITE', + target_nodes=target_nodes) + + def config_set(self, name, value, target_nodes=None): + "Set config item ``name`` with ``value``" + return self.execute_command('CONFIG SET', + name, + value, + target_nodes=target_nodes) + + def dbsize(self, target_nodes=None): + """ + Sums the number of keys in the target nodes' DB. + If no target nodes are specified, send to the entire cluster and sum + the results. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + return self.execute_command('DBSIZE', + target_nodes=target_nodes) + + def debug_object(self, key): + raise NotImplementedError( + "DEBUG OBJECT is intentionally not implemented in the client." + ) + + def debug_segfault(self): + raise NotImplementedError( + "DEBUG SEGFAULT is intentionally not implemented in the client." + ) + + def echo(self, value, target_nodes): + """Echo the string back from the server""" + return self.execute_command('ECHO', value, + target_nodes=target_nodes) + + def flushall(self, asynchronous=False, target_nodes=None): + """ + Delete all keys in the database on all hosts. + In cluster mode this method is the same as flushdb + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHALL', + *args, + target_nodes=target_nodes) + + def flushdb(self, asynchronous=False, target_nodes=None): + """ + Delete all keys in the database. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. 
+ """ + args = [] + if asynchronous: + args.append(b'ASYNC') + return self.execute_command('FLUSHDB', + *args, + target_nodes=target_nodes) + + def info(self, section=None, target_nodes=None): + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + """ + if section is None: + return self.execute_command('INFO', + target_nodes=target_nodes) + else: + return self.execute_command('INFO', + section, + target_nodes=target_nodes) + + def lastsave(self, target_nodes=None): + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + """ + return self.execute_command('LASTSAVE', + target_nodes=target_nodes) + + def memory_doctor(self): + raise NotImplementedError( + "MEMORY DOCTOR is intentionally not implemented in the client." + ) + + def memory_help(self): + raise NotImplementedError( + "MEMORY HELP is intentionally not implemented in the client." + ) + + def memory_malloc_stats(self, target_nodes=None): + """Return an internal statistics report from the memory allocator.""" + return self.execute_command('MEMORY MALLOC-STATS', + target_nodes=target_nodes) + + def memory_purge(self, target_nodes=None): + """Attempts to purge dirty pages for reclamation by allocator""" + return self.execute_command('MEMORY PURGE', + target_nodes=target_nodes) + + def memory_stats(self, target_nodes=None): + """Return a dictionary of memory stats""" + return self.execute_command('MEMORY STATS', + target_nodes=target_nodes) + + def memory_usage(self, key, samples=None): + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. + """ + args = [] + if isinstance(samples, int): + args.extend([b'SAMPLES', samples]) + return self.execute_command('MEMORY USAGE', key, *args) + + def migrate(self, host, source_node, port, keys, destination_db, timeout, + copy=False, replace=False, auth=None): + """ + Migrate 1 or more keys from the source_node Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError('MIGRATE requires at least one key') + pieces = [] + if copy: + pieces.append(b'COPY') + if replace: + pieces.append(b'REPLACE') + if auth: + pieces.append(b'AUTH') + pieces.append(auth) + pieces.append(b'KEYS') + pieces.extend(keys) + return self.execute_command('MIGRATE', host, port, '', destination_db, + timeout, *pieces, + target_nodes=source_node) + + def object(self, infotype, key): + """Return the encoding, idletime, or refcount about the key""" + return self.execute_command('OBJECT', infotype, key, infotype=infotype) + + def ping(self, target_nodes=None): + """ + Ping the cluster's servers. 
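
A short sketch of node targeting on these management commands (assuming the `get_primaries()` helper used elsewhere in this patch's tests; return values are illustrative):

``` pycon
    >>> rc.ping()                          # defaults to all nodes
    True
    >>> node = rc.get_primaries()[0]
    >>> rc.ping(target_nodes=node)         # a single ClusterNode
    True
    >>> rc.info(section='server', target_nodes=node)
    {'redis_version': '...', ...}
```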
+ If no target nodes are specified, sent to all nodes and returns True if + the ping was successful across all nodes. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + return self.execute_command('PING', + target_nodes=target_nodes) + + def save(self): + """ + Tell the Redis server to save its data to disk, + blocking until the save is complete + """ + return self.execute_command('SAVE') + + def shutdown(self, save=False, nosave=False): + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError('SHUTDOWN save and nosave cannot both be set') + args = ['SHUTDOWN'] + if save: + args.append('SAVE') + if nosave: + args.append('NOSAVE') + try: + self.execute_command(*args) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + def slowlog_get(self, num=None, target_nodes=None): + """ + Get the entries from the slowlog. If ``num`` is specified, get the + most recent ``num`` items. + """ + args = ['SLOWLOG GET'] + if num is not None: + args.append(num) + + return self.execute_command(*args, + target_nodes=target_nodes) + + def slowlog_len(self, target_nodes=None): + "Get the number of items in the slowlog" + return self.execute_command('SLOWLOG LEN', + target_nodes=target_nodes) + + def slowlog_reset(self, target_nodes=None): + "Remove all items in the slowlog" + return self.execute_command('SLOWLOG RESET', + target_nodes=target_nodes) + + def time(self, target_nodes=None): + """ + Returns the server time as a 2-item tuple of ints: + (seconds since epoch, microseconds into this second). + """ + return self.execute_command('TIME', target_nodes=target_nodes) + + def wait(self, num_replicas, timeout, target_nodes=None): + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. + + In cluster mode the WAIT command will be sent to all primaries + and the result will be summed up + """ + return self.execute_command('WAIT', num_replicas, + timeout, + target_nodes=target_nodes) + + +class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands, + DataAccessCommands, PubSubCommands): + def cluster_addslots(self, target_node, *slots): + """ + Assign new hash slots to receiving node. Sends to specified node. + + :target_node: 'ClusterNode' + The node to execute the command on + """ + return self.execute_command('CLUSTER ADDSLOTS', *slots, + target_nodes=target_node) + + def cluster_countkeysinslot(self, slot_id): + """ + Return the number of local keys in the specified hash slot + Send to node based on specified slot_id + """ + return self.execute_command('CLUSTER COUNTKEYSINSLOT', slot_id) + + def cluster_count_failure_report(self, node_id): + """ + Return the number of failure reports active for a given node + Sends to a random node + """ + return self.execute_command('CLUSTER COUNT-FAILURE-REPORTS', node_id) + + def cluster_delslots(self, *slots): + """ + Set hash slots as unbound in the cluster. 
It determines by itself which node each slot is hosted on and sends
+        the command there.
+
+        Returns a list of the results for each processed slot.
+        """
+        return [
+            self.execute_command('CLUSTER DELSLOTS', slot)
+            for slot in slots
+        ]
+
+    def cluster_failover(self, target_node, option=None):
+        """
+        Forces a replica to perform a manual failover of its master.
+        Sends to specified node.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        if option:
+            if option.upper() not in ['FORCE', 'TAKEOVER']:
+                raise RedisError(
+                    'Invalid option for CLUSTER FAILOVER command: {0}'.format(
+                        option))
+            else:
+                return self.execute_command('CLUSTER FAILOVER', option,
+                                            target_nodes=target_node)
+        else:
+            return self.execute_command('CLUSTER FAILOVER',
+                                        target_nodes=target_node)
+
+    def cluster_info(self, target_node=None):
+        """
+        Provides info about Redis Cluster node state.
+        The command will be sent to a random node in the cluster if no target
+        node is specified.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        return self.execute_command('CLUSTER INFO', target_nodes=target_node)
+
+    def cluster_keyslot(self, key):
+        """
+        Returns the hash slot of the specified key.
+        Sends to a random node in the cluster.
+        """
+        return self.execute_command('CLUSTER KEYSLOT', key)
+
+    def cluster_meet(self, target_nodes, host, port):
+        """
+        Force a node to handshake with another node.
+        Sends to specified node.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER MEET', host, port,
+                                    target_nodes=target_nodes)
+
+    def cluster_nodes(self):
+        """
+        Get the cluster configuration for the node.
+
+        Sends to a random node in the cluster.
+        """
+        return self.execute_command('CLUSTER NODES')
+
+    def cluster_replicate(self, target_nodes, node_id):
+        """
+        Reconfigure a node as a replica of the specified master node.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER REPLICATE', node_id,
+                                    target_nodes=target_nodes)
+
+    def cluster_reset(self, target_nodes, soft=True):
+        """
+        Reset a Redis Cluster node.
+
+        If 'soft' is True then it will send the 'SOFT' argument;
+        if 'soft' is False then it will send the 'HARD' argument.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER RESET',
+                                    b'SOFT' if soft else b'HARD',
+                                    target_nodes=target_nodes)
+
+    def cluster_save_config(self, target_nodes):
+        """
+        Forces the node to save the cluster state on disk.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER SAVECONFIG',
+                                    target_nodes=target_nodes)
+
+    def cluster_get_keys_in_slot(self, slot, num_keys):
+        """
+        Returns up to ``num_keys`` key names from the specified cluster slot.
+        """
+        return self.execute_command('CLUSTER GETKEYSINSLOT', slot, num_keys)
+
+    def cluster_set_config_epoch(self, target_nodes, epoch):
+        """
+        Set the configuration epoch for a new node.
+
+        :target_nodes: 'ClusterNode' or 'list(ClusterNodes)'
+            The node/s to execute the command on
+        """
+        return self.execute_command('CLUSTER SET-CONFIG-EPOCH', epoch,
+                                    target_nodes=target_nodes)
+
+    def cluster_setslot(self, target_node, node_id, slot_id, state):
+        """
+        Bind a hash slot to a specific node.
+
+        :target_node: 'ClusterNode'
+            The node to execute the command on
+        """
+        if state.upper() in
('IMPORTING', 'NODE', 'MIGRATING'): + return self.execute_command('CLUSTER SETSLOT', slot_id, state, + node_id, target_nodes=target_node) + elif state.upper() == 'STABLE': + raise RedisError('For "stable" state please use ' + 'cluster_setslot_stable') + else: + raise RedisError('Invalid slot state: {0}'.format(state)) + + def cluster_setslot_stable(self, slot_id): + """ + Clears migrating / importing state from the slot. + It determines by it self what node the slot is in and sends it there. + """ + return self.execute_command('CLUSTER SETSLOT', slot_id, 'STABLE') + + def cluster_replicas(self, node_id): + """ + Provides a list of replica nodes replicating from the specified primary + target node. + Sends to random node in the cluster. + """ + return self.execute_command('CLUSTER REPLICAS', node_id) + + def cluster_slots(self): + """ + Get array of Cluster slot to node mappings + + Sends to random node in the cluster + """ + return self.execute_command('CLUSTER SLOTS') + + def readonly(self, target_nodes=None): + """ + Enables read queries. + The command will be sent to all replica nodes if target_nodes is not + specified. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + self.read_from_replicas = True + return self.execute_command('READONLY', target_nodes=target_nodes) + + def readwrite(self, target_nodes=None): + """ + Disables read queries. + The command will be sent to all replica nodes if target_nodes is not + specified. + + :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' + The node/s to execute the command on + """ + # Reset read from replicas flag + self.read_from_replicas = False + return self.execute_command('READWRITE', target_nodes=target_nodes) diff --git a/redis/commands/core.py b/redis/commands/core.py index 6512b45a42..c6f589d21e 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -12,14 +12,7 @@ ) -class CoreCommands: - """ - A class containing all of the implemented redis commands. This class is - to be used as a mixin. - """ - - # SERVER INFORMATION - +class AclCommands: # ACL methods def acl_cat(self, category=None): """ @@ -267,6 +260,8 @@ def acl_whoami(self): "Get the username for the current connection" return self.execute_command('ACL WHOAMI') + +class ManagementCommands: def bgrewriteaof(self): "Tell the Redis server to rewrite the AOF file from data in memory." return self.execute_command('BGREWRITEAOF') @@ -431,6 +426,14 @@ def client_unpause(self): """ return self.execute_command('CLIENT UNPAUSE') + def command_info(self): + raise NotImplementedError( + "COMMAND INFO is intentionally not implemented in the client." + ) + + def command_count(self): + return self.execute_command('COMMAND COUNT') + def readwrite(self): """ Disables read queries for a connection to a Redis Cluster slave node. @@ -461,6 +464,9 @@ def config_rewrite(self): """ return self.execute_command('CONFIG REWRITE') + def cluster(self, cluster_arg, *args): + return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) + def dbsize(self): """Returns the number of keys in the current database""" return self.execute_command('DBSIZE') @@ -624,6 +630,16 @@ def quit(self): """ return self.execute_command('QUIT') + def replicaof(self, *args): + """ + Update the replication settings of a redis replica, on the fly. 
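
A sketch of toggling replica reads at runtime (assuming the `read_from_replicas` constructor flag referenced above):

``` pycon
    >>> rc = Redis(host='localhost', port=6379, read_from_replicas=True)
    >>> rc.readonly()   # sent to all replicas when target_nodes is omitted
    >>> rc.get('foo')   # may now be served by a replica of foo's slot
```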
+ Examples of valid arguments include: + NO ONE (set no replication) + host port (set to the host and port of a redis server) + see: https://redis.io/commands/replicaof + """ + return self.execute_command('REPLICAOF', *args) + def save(self): """ Tell the Redis server to save its data to disk, @@ -698,6 +714,8 @@ def wait(self, num_replicas, timeout): """ return self.execute_command('WAIT', num_replicas, timeout) + +class BasicKeyCommands: # BASIC KEY COMMANDS def append(self, key, value): """ @@ -1324,6 +1342,8 @@ def unlink(self, *names): "Unlink one or more keys specified by ``names``" return self.execute_command('UNLINK', *names) + +class ListCommands: # LIST COMMANDS def blpop(self, keys, timeout=0): """ @@ -1576,6 +1596,8 @@ def sort(self, name, start=None, num=None, by=None, get=None, options = {'groups': len(get) if groups else None} return self.execute_command('SORT', *pieces, **options) + +class ScanCommands: # SCAN COMMANDS def scan(self, cursor=0, match=None, count=None, _type=None): """ @@ -1723,6 +1745,8 @@ def zscan_iter(self, name, match=None, count=None, score_cast_func=score_cast_func) yield from data + +class SetCommands: # SET COMMANDS def sadd(self, name, *values): """Add ``value(s)`` to set ``name``""" @@ -1805,6 +1829,8 @@ def sunionstore(self, dest, keys, *args): args = list_or_args(keys, args) return self.execute_command('SUNIONSTORE', dest, *args) + +class StreamsCommands: # STREAMS COMMANDS def xack(self, name, groupname, *ids): """ @@ -2243,6 +2269,8 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, return self.execute_command('XTRIM', name, *pieces) + +class SortedSetCommands: # SORTED SET COMMANDS def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None): @@ -2721,6 +2749,8 @@ def _zaggregate(self, command, dest, keys, aggregate=None, pieces.append(b'WITHSCORES') return self.execute_command(*pieces, **options) + +class HyperLogLogCommands: # HYPERLOGLOG COMMANDS def pfadd(self, name, *values): "Adds the specified elements to the specified HyperLogLog." @@ -2737,6 +2767,8 @@ def pfmerge(self, dest, *sources): "Merge N different HyperLogLogs into a single one." return self.execute_command('PFMERGE', dest, *sources) + +class HashCommands: # HASH COMMANDS def hdel(self, name, *keys): "Delete ``keys`` from hash ``name``" @@ -2831,6 +2863,9 @@ def hstrlen(self, name, key): """ return self.execute_command('HSTRLEN', name, key) + +class PubSubCommands: + # PUBSUB COMMANDS def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -2857,19 +2892,9 @@ def pubsub_numsub(self, *args): """ return self.execute_command('PUBSUB NUMSUB', *args) - def cluster(self, cluster_arg, *args): - return self.execute_command('CLUSTER %s' % cluster_arg.upper(), *args) - - def replicaof(self, *args): - """ - Update the replication settings of a redis replica, on the fly. 
- Examples of valid arguments include: - NO ONE (set no replication) - host port (set to the host and port of a redis server) - see: https://redis.io/commands/replicaof - """ - return self.execute_command('REPLICAOF', *args) +class ScriptCommands: + # SCRIPT COMMANDS def eval(self, script, numkeys, *keys_and_args): """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -2941,6 +2966,8 @@ def register_script(self, script): """ return Script(self, script) + +class GeoCommands: # GEO COMMANDS def geoadd(self, name, values, nx=False, xx=False, ch=False): """ @@ -3235,6 +3262,8 @@ def _geosearchgeneric(self, command, *args, **kwargs): return self.execute_command(command, *pieces, **kwargs) + +class ModuleCommands: # MODULE COMMANDS def module_load(self, path, *args): """ @@ -3258,14 +3287,6 @@ def module_list(self): """ return self.execute_command('MODULE LIST') - def command_info(self): - raise NotImplementedError( - "COMMAND INFO is intentionally not implemented in the client." - ) - - def command_count(self): - return self.execute_command('COMMAND COUNT') - class Script: "An executable Lua script object returned by ``register_script``" @@ -3397,3 +3418,22 @@ def execute(self): command = self.command self.reset() return self.client.execute_command(*command) + + +class DataAccessCommands(BasicKeyCommands, ListCommands, + ScanCommands, SetCommands, StreamsCommands, + SortedSetCommands, + HyperLogLogCommands, HashCommands, GeoCommands, + ): + """ + A class containing all of the implemented data access redis commands. + This class is to be used as a mixin. + """ + + +class CoreCommands(AclCommands, DataAccessCommands, ManagementCommands, + ModuleCommands, PubSubCommands, ScriptCommands): + """ + A class containing all of the implemented redis commands. This class is + to be used as a mixin. + """ diff --git a/redis/commands/parser.py b/redis/commands/parser.py new file mode 100644 index 0000000000..22478ed2ed --- /dev/null +++ b/redis/commands/parser.py @@ -0,0 +1,108 @@ +from redis.exceptions import ( + RedisError, + ResponseError +) +from redis.utils import str_if_bytes + + +class CommandsParser: + """ + Parses Redis commands to get command keys. + COMMAND output is used to determine key locations. + Commands that do not have a predefined key location are flagged with + 'movablekeys', and these commands' keys are determined by the command + 'COMMAND GETKEYS'. + """ + def __init__(self, redis_connection): + self.initialized = False + self.commands = {} + self.initialize(redis_connection) + + def initialize(self, r): + self.commands = r.execute_command("COMMAND") + + # As soon as this PR is merged into Redis, we should reimplement + # our logic to use COMMAND INFO changes to determine the key positions + # https://github.com/redis/redis/pull/8324 + def get_keys(self, redis_conn, *args): + """ + Get the keys from the passed command + """ + if len(args) < 2: + # The command has no keys in it + return None + + cmd_name = args[0].lower() + cmd_name_split = cmd_name.split() + if len(cmd_name_split) > 1: + # we need to take only the main command, e.g. 'memory' for + # 'memory usage' + cmd_name = cmd_name_split[0] + if cmd_name not in self.commands: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError("{0} command doesn't exist in Redis commands". 
+ format(cmd_name.upper())) + + command = self.commands.get(cmd_name) + if 'movablekeys' in command['flags']: + keys = self._get_moveable_keys(redis_conn, *args) + elif 'pubsub' in command['flags']: + keys = self._get_pubsub_keys(*args) + else: + if command['step_count'] == 0 and command['first_key_pos'] == 0 \ + and command['last_key_pos'] == 0: + # The command doesn't have keys in it + return None + last_key_pos = command['last_key_pos'] + if last_key_pos == -1: + last_key_pos = len(args) - 1 + keys_pos = list(range(command['first_key_pos'], last_key_pos + 1, + command['step_count'])) + keys = [args[pos] for pos in keys_pos] + + return keys + + def _get_moveable_keys(self, redis_conn, *args): + try: + pieces = [] + cmd_name = args[0] + for arg in cmd_name.split(): + # The command name should be splitted into separate arguments, + # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE'] + pieces.append(arg) + pieces += args[1:] + keys = redis_conn.execute_command('COMMAND GETKEYS', *pieces) + except ResponseError as e: + message = e.__str__() + if 'Invalid arguments' in message or \ + 'The command has no key arguments' in message: + return None + else: + raise e + return keys + + def _get_pubsub_keys(self, *args): + """ + Get the keys from pubsub command. + Although PubSub commands have predetermined key locations, they are not + supported in the 'COMMAND's output, so the key positions are hardcoded + in this method + """ + if len(args) < 2: + # The command has no keys in it + return None + args = [str_if_bytes(arg) for arg in args] + command = args[0].upper() + if command in ['PUBLISH', 'PUBSUB CHANNELS']: + # format example: + # PUBLISH channel message + keys = [args[1]] + elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', + 'PUNSUBSCRIBE', 'PUBSUB NUMSUB']: + keys = list(args[1:]) + else: + keys = None + return keys diff --git a/redis/connection.py b/redis/connection.py index c99c550ecd..e1ad6ea7f2 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -11,6 +11,7 @@ import threading import warnings +from redis.backoff import NoBackoff from redis.exceptions import ( AuthenticationError, AuthenticationWrongNumberOfArgsError, @@ -28,9 +29,9 @@ TimeoutError, ModuleError, ) -from redis.utils import HIREDIS_AVAILABLE, str_if_bytes -from redis.backoff import NoBackoff + from redis.retry import Retry +from redis.utils import HIREDIS_AVAILABLE, str_if_bytes try: import ssl @@ -506,7 +507,7 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, encoding_errors='strict', decode_responses=False, parser_class=DefaultParser, socket_read_size=65536, health_check_interval=0, client_name=None, username=None, - retry=None): + retry=None, redis_connect_func=None): """ Initialize a new Connection. 
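
To illustrate the parser, a sketch against a standalone server (exact return types depend on `decode_responses`; GEORADIUS is one of the commands flagged 'movablekeys'):

``` pycon
    >>> from redis import Redis
    >>> from redis.commands import CommandsParser
    >>> r = Redis()
    >>> cp = CommandsParser(r)
    >>> cp.get_keys(r, 'SET', 'foo', 'bar')              # fixed key position
    ['foo']
    >>> cp.get_keys(r, 'MSET', 'k1', 'v1', 'k2', 'v2')   # step_count of 2
    ['k1', 'k2']
    >>> cp.get_keys(r, 'GEORADIUS', 'src', 0, 0, 1, 'km', 'STORE', 'dst')
    [b'src', b'dst']
```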
To specify a retry policy, first set `retry_on_timeout` to `True` @@ -536,8 +537,10 @@ def __init__(self, host='localhost', port=6379, db=0, password=None, self.health_check_interval = health_check_interval self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self.redis_connect_func = redis_connect_func self._sock = None - self._parser = parser_class(socket_read_size=socket_read_size) + self._socket_read_size = socket_read_size + self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 @@ -567,6 +570,9 @@ def register_connect_callback(self, callback): def clear_connect_callbacks(self): self._connect_callbacks = [] + def set_parser(self, parser_class): + self._parser = parser_class(socket_read_size=self._socket_read_size) + def connect(self): "Connects to the Redis server if not already connected" if self._sock: @@ -580,7 +586,12 @@ def connect(self): self._sock = sock try: - self.on_connect() + if self.redis_connect_func is None: + # Use the default on_connect function + self.on_connect() + else: + # Use the passed function redis_connect_func + self.redis_connect_func(self) except RedisError: # clean up after any error in on_connect self.disconnect() @@ -910,7 +921,8 @@ def __init__(self, path='', db=0, username=None, password=None, self.next_health_check = 0 self.encoder = Encoder(encoding, encoding_errors, decode_responses) self._sock = None - self._parser = parser_class(socket_read_size=socket_read_size) + self._socket_read_size = socket_read_size + self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 diff --git a/redis/crc.py b/redis/crc.py new file mode 100644 index 0000000000..a4dfdf69f5 --- /dev/null +++ b/redis/crc.py @@ -0,0 +1,28 @@ +from binascii import crc_hqx + +# Redis Cluster's key space is divided into 16384 slots. +# For more information see: https://github.com/redis/redis/issues/2576 +REDIS_CLUSTER_HASH_SLOTS = 16384 + +__all__ = [ + "crc16", + "key_slot", + "REDIS_CLUSTER_HASH_SLOTS" +] + + +def crc16(data): + return crc_hqx(data, 0) + + +def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS): + """Calculate key slot for a given key. 
+ :param key - bytes + :param bucket - int + """ + start = key.find(b"{") + if start > -1: + end = key.find(b"}", start + 1) + if end > -1 and end != start + 1: + key = key[start + 1: end] + return crc16(key) % bucket diff --git a/redis/exceptions.py b/redis/exceptions.py index 91eb3c7257..5ea7fe9c30 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -84,3 +84,67 @@ class AuthenticationWrongNumberOfArgsError(ResponseError): were sent to the AUTH command """ pass + + +class RedisClusterException(Exception): + pass + + +class ClusterError(RedisError): + pass + + +class ClusterDownError(ClusterError, ResponseError): + + def __init__(self, resp): + self.args = (resp,) + self.message = resp + + +class AskError(ResponseError): + """ + src node: MIGRATING to dst node + get > ASK error + ask dst node > ASKING command + dst node: IMPORTING from src node + asking command only affects next command + any op will be allowed after asking command + """ + + def __init__(self, resp): + """should only redirect to master node""" + self.args = (resp,) + self.message = resp + slot_id, new_node = resp.split(' ') + host, port = new_node.rsplit(':', 1) + self.slot_id = int(slot_id) + self.node_addr = self.host, self.port = host, int(port) + + +class TryAgainError(ResponseError): + + def __init__(self, *args, **kwargs): + pass + + +class ClusterCrossSlotError(ResponseError): + message = "Keys in request don't hash to the same slot" + + +class MovedError(AskError): + pass + + +class MasterDownError(ClusterDownError): + pass + + +class SlotNotCoveredError(RedisClusterException): + """ + This error only happens in the case where the connection pool will try to + fetch what node that is covered by a given slot. + + If this error is raised the client should drop the current node layout and + attempt to reconnect and refresh the node layout again + """ + pass diff --git a/redis/utils.py b/redis/utils.py index 26fb002b89..0e78cc5f3b 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -36,3 +36,39 @@ def str_if_bytes(value): def safe_str(value): return str(str_if_bytes(value)) + + +def dict_merge(*dicts): + """ + Merge all provided dicts into 1 dict. + *dicts : `dict` + dictionaries to merge + """ + merged = {} + + for d in dicts: + merged.update(d) + + return merged + + +def list_keys_to_dict(key_list, callback): + return dict.fromkeys(key_list, callback) + + +def merge_result(command, res): + """ + Merge all items in `res` into a list. + + This command is used when sending a command to multiple nodes + and they result from each node should be merged into a single list. + + res : 'dict' + """ + result = set() + + for v in res.values(): + for value in v: + result.add(value) + + return list(result) diff --git a/tasks.py b/tasks.py index aa965c6902..44b652908d 100644 --- a/tasks.py +++ b/tasks.py @@ -40,7 +40,10 @@ def tests(c): """Run the redis-py test suite against the current python, with and without hiredis. 
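
Because only the substring between the first '{' and the next '}' is hashed, keys that share a hash tag map to the same slot and can be used together in multi-key operations. A small sketch (the slot value assumes the standard CRC16/XModem table used by Redis):

``` pycon
    >>> from redis.crc import key_slot
    >>> key_slot(b'foo')
    12182
    >>> key_slot(b'{user1000}.following') == key_slot(b'{user1000}.followers')
    True
```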
""" + print("Starting Redis tests") run("tox -e plain -e hiredis") + print("Starting RedisCluster tests") + run("tox -e plain -e hiredis -- --redis-url=redis://localhost:16379/0") @task diff --git a/tests/conftest.py b/tests/conftest.py index 47188df07f..df809bf81d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,17 +3,19 @@ import pytest import random import redis +import time from distutils.version import LooseVersion from redis.connection import parse_url +from redis.exceptions import RedisClusterException from unittest.mock import Mock from urllib.parse import urlparse - REDIS_INFO = {} default_redis_url = "redis://localhost:6379/9" default_redismod_url = "redis://localhost:36379/9" default_redismod_url = "redis://localhost:36379" +default_cluster_nodes = 6 def pytest_addoption(parser): @@ -28,6 +30,12 @@ def pytest_addoption(parser): " with loaded modules," " defaults to `%(default)s`") + parser.addoption('--cluster-nodes', default=default_cluster_nodes, + action="store", + help="The number of cluster nodes that need to be " + "available before the test can start," + " defaults to `%(default)s`") + def _get_info(redis_url): client = redis.Redis.from_url(redis_url) @@ -41,14 +49,37 @@ def pytest_sessionstart(session): info = _get_info(redis_url) version = info["redis_version"] arch_bits = info["arch_bits"] + cluster_enabled = info["cluster_enabled"] REDIS_INFO["version"] = version REDIS_INFO["arch_bits"] = arch_bits + REDIS_INFO["cluster_enabled"] = cluster_enabled # module info redismod_url = session.config.getoption("--redismod-url") info = _get_info(redismod_url) REDIS_INFO["modules"] = info["modules"] + if cluster_enabled: + cluster_nodes = session.config.getoption("--cluster-nodes") + wait_for_cluster_creation(redis_url, cluster_nodes) + + +def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): + now = time.time() + timeout = now + timeout + print("Waiting for {0} cluster nodes to become available". 
+ format(cluster_nodes)) + while now < timeout: + try: + client = redis.RedisCluster.from_url(redis_url) + if len(client.get_nodes()) == cluster_nodes: + print("All nodes are available!") + break + except RedisClusterException: + pass + time.sleep(1) + now = time.time() + def skip_if_server_version_lt(min_version): redis_version = REDIS_INFO["version"] @@ -86,6 +117,17 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): raise AttributeError("No redis module named {}".format(module_name)) +def skip_if_cluster_mode(): + return pytest.mark.skipif(REDIS_INFO["cluster_enabled"], + reason="This test isn't supported with cluster " + "mode") + + +def skip_if_not_cluster_mode(): + return pytest.mark.skipif(not REDIS_INFO["cluster_enabled"], + reason="Cluster-mode is required for this test") + + def _get_client(cls, request, single_connection_client=True, flushdb=True, from_url=None, **kwargs): @@ -100,10 +142,14 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, redis_url = request.config.getoption("--redis-url") else: redis_url = from_url - url_options = parse_url(redis_url) - url_options.update(kwargs) - pool = redis.ConnectionPool(**url_options) - client = cls(connection_pool=pool) + if REDIS_INFO["cluster_enabled"]: + client = redis.RedisCluster.from_url(redis_url, **kwargs) + single_connection_client = False + else: + url_options = parse_url(redis_url) + url_options.update(kwargs) + pool = redis.ConnectionPool(**url_options) + client = cls(connection_pool=pool) if single_connection_client: client = client.client() if request: @@ -116,7 +162,10 @@ def teardown(): # just manually retry the flushdb client.flushdb() client.close() - client.connection_pool.disconnect() + if REDIS_INFO["cluster_enabled"]: + client.disconnect_connection_pools() + else: + client.connection_pool.disconnect() request.addfinalizer(teardown) return client @@ -220,6 +269,13 @@ def master_host(request): yield parts.hostname +@pytest.fixture(scope="session") +def master_port(request): + url = request.config.getoption("--redis-url") + parts = urlparse(url) + yield parts.port + + def wait_for_command(client, monitor, command): # issue a command with a key name that's local to this process. # if we find a command with our key before the command we're waiting diff --git a/tests/test_cluster.py b/tests/test_cluster.py new file mode 100644 index 0000000000..fa78c04710 --- /dev/null +++ b/tests/test_cluster.py @@ -0,0 +1,1403 @@ +import pytest +import datetime +import warnings + +from time import sleep +from unittest.mock import call, patch, DEFAULT, Mock +from redis import Redis +from redis.cluster import get_node_name, ClusterNode, \ + RedisCluster, NodesManager, PRIMARY, REDIS_CLUSTER_HASH_SLOTS, REPLICA +from redis.commands import CommandsParser +from redis.connection import Connection +from redis.utils import str_if_bytes +from redis.exceptions import ( + AskError, + ClusterDownError, + MovedError, + RedisClusterException, + RedisError +) + +from redis.crc import key_slot +from .conftest import ( + skip_if_not_cluster_mode, + _get_client, + skip_if_server_version_lt +) + +default_host = "127.0.0.1" +default_port = 7000 +default_cluster_slots = [ + [ + 0, 8191, + ['127.0.0.1', 7000, 'node_0'], + ['127.0.0.1', 7003, 'node_3'], + ], + [ + 8192, 16383, + ['127.0.0.1', 7001, 'node_1'], + ['127.0.0.1', 7002, 'node_2'] + ] +] + + +@pytest.fixture() +def slowlog(request, r): + """ + Set the slowlog threshold to 0, and the + max length to 128. 
This will force every
+    command into the slowlog and allow us
+    to test it
+    """
+    # Save old values
+    current_config = r.config_get(
+        target_nodes=r.get_primaries()[0])
+    old_slower_than_value = current_config['slowlog-log-slower-than']
+    old_max_length_value = current_config['slowlog-max-len']
+
+    # Function to restore the old values
+    def cleanup():
+        r.config_set('slowlog-log-slower-than', old_slower_than_value)
+        r.config_set('slowlog-max-len', old_max_length_value)
+    request.addfinalizer(cleanup)
+
+    # Set the new values
+    r.config_set('slowlog-log-slower-than', 0)
+    r.config_set('slowlog-max-len', 128)
+
+
+def get_mocked_redis_client(func=None, *args, **kwargs):
+    """
+    Return a stable RedisCluster object that has a deterministic
+    nodes and slots setup to remove the problem of different IP addresses
+    on different installations and machines.
+    """
+    cluster_slots = kwargs.pop('cluster_slots', default_cluster_slots)
+    coverage_res = kwargs.pop('coverage_result', 'yes')
+    with patch.object(Redis, 'execute_command') as execute_command_mock:
+        def execute_command(*_args, **_kwargs):
+            if _args[0] == 'CLUSTER SLOTS':
+                mock_cluster_slots = cluster_slots
+                return mock_cluster_slots
+            elif _args[1] == 'cluster-require-full-coverage':
+                return {'cluster-require-full-coverage': coverage_res}
+            elif func is not None:
+                return func(*args, **kwargs)
+            else:
+                return execute_command_mock(*_args, **_kwargs)
+
+        execute_command_mock.side_effect = execute_command
+
+        with patch.object(CommandsParser, 'initialize',
+                          autospec=True) as cmd_parser_initialize:
+
+            def cmd_init_mock(self, r):
+                self.commands = {'get': {'name': 'get', 'arity': 2,
+                                         'flags': ['readonly',
+                                                   'fast'],
+                                         'first_key_pos': 1,
+                                         'last_key_pos': 1,
+                                         'step_count': 1}}
+
+            cmd_parser_initialize.side_effect = cmd_init_mock
+
+            return RedisCluster(*args, **kwargs)
+
+
+def mock_node_resp(node, response):
+    connection = Mock()
+    connection.read_response.return_value = response
+    node.redis_connection.connection = connection
+    return node
+
+
+def mock_all_nodes_resp(rc, response):
+    for node in rc.get_nodes():
+        mock_node_resp(node, response)
+    return rc
+
+
+def find_node_ip_based_on_port(cluster_client, port):
+    for node in cluster_client.get_nodes():
+        if node.port == port:
+            return node.host
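For reference, the mocked-client helper above accepts the same `cluster_slots` layout the tests use; driven with this module's `default_cluster_slots`, it should yield a deterministic four-node topology (a sketch of intended usage, not an assertion from the suite):

``` pycon
    >>> rc = get_mocked_redis_client(host=default_host, port=default_port,
    ...                              cluster_slots=default_cluster_slots)
    >>> len(rc.get_nodes())
    4
```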
+def moved_redirection_helper(request, failover=False):
+    """
+    Test that the client handles MOVED response after a failover.
+    Redirection after a failover means that the redirection address is of a
+    replica that was promoted to a primary.
+
+    At first call it should return a MOVED ResponseError that will point
+    the client to the next server it should talk to.
+
+    Verify that:
+    1. it tries to talk to the redirected node
+    2. it updates the slot's primary to the redirected node
+
+    For a failover, also verify:
+    3. the redirected node's server type updated to 'primary'
+    4. the server type of the previous slot owner updated to 'replica'
+    """
+    rc = _get_client(RedisCluster, request, flushdb=False)
+    slot = 12182
+    redirect_node = None
+    # Get the current primary that holds this slot
+    prev_primary = rc.nodes_manager.get_node_from_slot(slot)
+    if failover:
+        if len(rc.nodes_manager.slots_cache[slot]) < 2:
+            warnings.warn("Skipping this test since it requires to have a "
+                          "replica")
+            return
+        redirect_node = rc.nodes_manager.slots_cache[slot][1]
+    else:
+        # Use one of the primaries to be the redirected node
+        redirect_node = rc.get_primaries()[0]
+    r_host = redirect_node.host
+    r_port = redirect_node.port
+    with patch.object(Redis, 'parse_response') as parse_response:
+        def moved_redirect_effect(connection, *args, **options):
+            def ok_response(connection, *args, **options):
+                assert connection.host == r_host
+                assert connection.port == r_port
+
+                return "MOCK_OK"
+
+            parse_response.side_effect = ok_response
+            raise MovedError("{0} {1}:{2}".format(slot, r_host, r_port))
+
+        parse_response.side_effect = moved_redirect_effect
+        assert rc.execute_command("SET", "foo", "bar") == "MOCK_OK"
+        slot_primary = rc.nodes_manager.slots_cache[slot][0]
+        assert slot_primary == redirect_node
+        if failover:
+            assert rc.get_node(host=r_host, port=r_port).server_type == \
+                PRIMARY
+            assert prev_primary.server_type == REPLICA
+
+
+@skip_if_not_cluster_mode()
+class TestRedisClusterObj:
+    def test_host_port_startup_node(self):
+        """
+        Test that it is possible to use host & port arguments as startup node
+        args
+        """
+        cluster = get_mocked_redis_client(host=default_host,
+                                          port=default_port)
+        assert cluster.get_node(host=default_host,
+                                port=default_port) is not None
+
+    def test_startup_nodes(self):
+        """
+        Test that it is possible to use the startup_nodes
+        argument to init the cluster
+        """
+        port_1 = 7000
+        port_2 = 7001
+        startup_nodes = [ClusterNode(default_host, port_1),
+                         ClusterNode(default_host, port_2)]
+        cluster = get_mocked_redis_client(startup_nodes=startup_nodes)
+        assert cluster.get_node(host=default_host, port=port_1) is not None \
+            and cluster.get_node(host=default_host, port=port_2) is not None
+
+    def test_empty_startup_nodes(self):
+        """
+        Test that an exception is raised when providing empty startup_nodes
+        """
+        with pytest.raises(RedisClusterException) as ex:
+            RedisCluster(startup_nodes=[])
+
+        assert str(ex.value).startswith(
+            "RedisCluster requires at least one node to discover the "
+            "cluster"), str_if_bytes(ex.value)
+
+    def test_from_url(self, r):
+        redis_url = "redis://{0}:{1}/0".format(default_host, default_port)
+        with patch.object(RedisCluster, 'from_url') as from_url:
+            def from_url_mocked(_url, **_kwargs):
+                return get_mocked_redis_client(url=_url, **_kwargs)
+
+            from_url.side_effect = from_url_mocked
+            cluster = RedisCluster.from_url(redis_url)
+        assert cluster.get_node(host=default_host,
+                                port=default_port) is not None
+
+    def test_execute_command_errors(self, r):
+        """
+        Test that an exception is raised if no key is provided.
+        """
+        with pytest.raises(RedisClusterException) as ex:
+            r.execute_command("GET")
+        assert str(ex.value).startswith("No way to dispatch this command to "
+                                        "Redis Cluster.
Missing key.") + + def test_execute_command_node_flag_primaries(self, r): + """ + Test command execution with nodes flag PRIMARIES + """ + primaries = r.get_primaries() + replicas = r.get_replicas() + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.PRIMARIES) is True + for primary in primaries: + conn = primary.redis_connection.connection + assert conn.read_response.called is True + for replica in replicas: + conn = replica.redis_connection.connection + assert conn.read_response.called is not True + + def test_execute_command_node_flag_replicas(self, r): + """ + Test command execution with nodes flag REPLICAS + """ + replicas = r.get_replicas() + if not replicas: + r = get_mocked_redis_client(default_host, default_port) + primaries = r.get_primaries() + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.REPLICAS) is True + for replica in replicas: + conn = replica.redis_connection.connection + assert conn.read_response.called is True + for primary in primaries: + conn = primary.redis_connection.connection + assert conn.read_response.called is not True + + def test_execute_command_node_flag_all_nodes(self, r): + """ + Test command execution with nodes flag ALL_NODES + """ + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.ALL_NODES) is True + for node in r.get_nodes(): + conn = node.redis_connection.connection + assert conn.read_response.called is True + + def test_execute_command_node_flag_random(self, r): + """ + Test command execution with nodes flag RANDOM + """ + mock_all_nodes_resp(r, 'PONG') + assert r.ping(RedisCluster.RANDOM) is True + called_count = 0 + for node in r.get_nodes(): + conn = node.redis_connection.connection + if conn.read_response.called is True: + called_count += 1 + assert called_count == 1 + + @pytest.mark.filterwarnings("ignore:AskError") + def test_ask_redirection(self, r): + """ + Test that the server handles ASK response. + + At first call it should return a ASK ResponseError that will point + the client to the next server it should talk to. + + Important thing to verify is that it tries to talk to the second node. + """ + redirect_node = r.get_nodes()[0] + with patch.object(Redis, 'parse_response') as parse_response: + def ask_redirect_effect(connection, *args, **options): + def ok_response(connection, *args, **options): + assert connection.host == redirect_node.host + assert connection.port == redirect_node.port + + return "MOCK_OK" + + parse_response.side_effect = ok_response + raise AskError("12182 {0}:{1}".format(redirect_node.host, + redirect_node.port)) + + parse_response.side_effect = ask_redirect_effect + + assert r.execute_command("SET", "foo", "bar") == "MOCK_OK" + + @pytest.mark.filterwarnings("ignore:MovedError") + def test_moved_redirection(self, request): + """ + Test that the client handles MOVED response. + """ + moved_redirection_helper(request, failover=False) + + @pytest.mark.filterwarnings("ignore:MovedError") + def test_moved_redirection_after_failover(self, request): + """ + Test that the client handles MOVED response after a failover. 
+ """ + moved_redirection_helper(request, failover=True) + + @pytest.mark.filterwarnings("ignore:ClusterDownError") + def test_refresh_using_specific_nodes(self, request): + """ + Test making calls on specific nodes when the cluster has failed over to + another node + """ + node_7006 = ClusterNode(host=default_host, port=7006, + server_type=PRIMARY) + node_7007 = ClusterNode(host=default_host, port=7007, + server_type=PRIMARY) + with patch.object(Redis, 'parse_response') as parse_response: + with patch.object(NodesManager, 'initialize', autospec=True) as \ + initialize: + with patch.multiple(Connection, + send_command=DEFAULT, + connect=DEFAULT, + can_read=DEFAULT) as mocks: + # simulate 7006 as a failed node + def parse_response_mock(connection, command_name, + **options): + if connection.port == 7006: + parse_response.failed_calls += 1 + raise ClusterDownError( + 'CLUSTERDOWN The cluster is ' + 'down. Use CLUSTER INFO for ' + 'more information') + elif connection.port == 7007: + parse_response.successful_calls += 1 + + def initialize_mock(self): + # start with all slots mapped to 7006 + self.nodes_cache = {node_7006.name: node_7006} + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7006] + + # After the first connection fails, a reinitialize + # should follow the cluster to 7007 + def map_7007(self): + self.nodes_cache = { + node_7007.name: node_7007} + self.slots_cache = {} + + for i in range(0, 16383): + self.slots_cache[i] = [node_7007] + + # Change initialize side effect for the second call + initialize.side_effect = map_7007 + + parse_response.side_effect = parse_response_mock + parse_response.successful_calls = 0 + parse_response.failed_calls = 0 + initialize.side_effect = initialize_mock + mocks['can_read'].return_value = False + mocks['send_command'].return_value = "MOCK_OK" + mocks['connect'].return_value = None + with patch.object(CommandsParser, 'initialize', + autospec=True) as cmd_parser_initialize: + + def cmd_init_mock(self, r): + self.commands = {'get': {'name': 'get', 'arity': 2, + 'flags': ['readonly', + 'fast'], + 'first_key_pos': 1, + 'last_key_pos': 1, + 'step_count': 1}} + + cmd_parser_initialize.side_effect = cmd_init_mock + + rc = _get_client( + RedisCluster, request, flushdb=False) + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7006.name) is not \ + None + + rc.get('foo') + + # Cluster should now point to 7007, and there should be + # one failed and one successful call + assert len(rc.get_nodes()) == 1 + assert rc.get_node(node_name=node_7007.name) is not \ + None + assert rc.get_node(node_name=node_7006.name) is None + assert parse_response.failed_calls == 1 + assert parse_response.successful_calls == 1 + + def test_reading_from_replicas_in_round_robin(self): + with patch.multiple(Connection, send_command=DEFAULT, + read_response=DEFAULT, _connect=DEFAULT, + can_read=DEFAULT, on_connect=DEFAULT) as mocks: + with patch.object(Redis, 'parse_response') as parse_response: + def parse_response_mock_first(connection, *args, **options): + # Primary + assert connection.port == 7001 + parse_response.side_effect = parse_response_mock_second + return "MOCK_OK" + + def parse_response_mock_second(connection, *args, **options): + # Replica + assert connection.port == 7002 + parse_response.side_effect = parse_response_mock_third + return "MOCK_OK" + + def parse_response_mock_third(connection, *args, **options): + # Primary + assert connection.port == 7001 + return "MOCK_OK" + + # We don't need to create a real cluster 
connection but we
+                # do want RedisCluster.on_connect function to get called,
+                # so we'll mock some of the Connection's functions to allow it
+                parse_response.side_effect = parse_response_mock_first
+                mocks['send_command'].return_value = True
+                mocks['read_response'].return_value = "OK"
+                mocks['_connect'].return_value = True
+                mocks['can_read'].return_value = False
+                mocks['on_connect'].return_value = True
+
+                # Create a cluster that reads from replicas
+                read_cluster = get_mocked_redis_client(
+                    host=default_host,
+                    port=default_port,
+                    read_from_replicas=True)
+                assert read_cluster.read_from_replicas is True
+                # Check that we read from the slot's nodes in a round robin
+                # manner.
+                # 'foo' belongs to slot 12182 and the slot's nodes are:
+                # [(127.0.0.1,7001,primary), (127.0.0.1,7002,replica)]
+                read_cluster.get("foo")
+                read_cluster.get("foo")
+                read_cluster.get("foo")
+                mocks['send_command'].assert_has_calls([call('READONLY')])
+
+    def test_keyslot(self, r):
+        """
+        Test that the method computes the correct key slot in all supported
+        cases
+        """
+        assert r.keyslot("foo") == 12182
+        assert r.keyslot("{foo}bar") == 12182
+        assert r.keyslot("{foo}") == 12182
+        assert r.keyslot(1337) == 4314
+
+        assert r.keyslot(125) == r.keyslot(b"125")
+        assert r.keyslot(125) == r.keyslot("\x31\x32\x35")
+        assert r.keyslot("大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96")
+        assert r.keyslot(u"大奖") == r.keyslot(b"\xe5\xa4\xa7\xe5\xa5\x96")
+        assert r.keyslot(1337.1234) == r.keyslot("1337.1234")
+        assert r.keyslot(1337) == r.keyslot("1337")
+        assert r.keyslot(b"abc") == r.keyslot("abc")
+
+    def test_get_node_name(self):
+        assert get_node_name(default_host, default_port) == \
+            "{0}:{1}".format(default_host, default_port)
+
+    def test_all_nodes(self, r):
+        """
+        Set a list of nodes and it should be possible to iterate over all
+        """
+        nodes = [node for node in r.nodes_manager.nodes_cache.values()]
+
+        for i, node in enumerate(r.get_nodes()):
+            assert node in nodes
+
+    def test_all_nodes_masters(self, r):
+        """
+        Set a list of nodes with a random primaries/replicas config and it
+        should be possible to iterate over all of them.
+        """
+        nodes = [node for node in r.nodes_manager.nodes_cache.values()
+                 if node.server_type == PRIMARY]
+
+        for node in r.get_primaries():
+            assert node in nodes
+
+    @pytest.mark.filterwarnings("ignore:ClusterDownError")
+    def test_cluster_down_overreaches_retry_attempts(self):
+        """
+        When ClusterDownError is thrown, test that we retry executing the
+        command as many times as configured in cluster_error_retry_attempts
+        and then raise the exception
+        """
+        with patch.object(RedisCluster, '_execute_command') as execute_command:
+            def raise_cluster_down_error(target_node, *args, **kwargs):
+                execute_command.failed_calls += 1
+                raise ClusterDownError(
+                    'CLUSTERDOWN The cluster is down.
Use CLUSTER INFO for ' + 'more information') + + execute_command.side_effect = raise_cluster_down_error + + rc = get_mocked_redis_client(host=default_host, port=default_port) + + with pytest.raises(ClusterDownError): + rc.get("bar") + assert execute_command.failed_calls == \ + rc.cluster_error_retry_attempts + + @pytest.mark.filterwarnings("ignore:ConnectionError") + def test_connection_error_overreaches_retry_attempts(self): + """ + When ConnectionError is thrown, test that we retry executing the + command as many times as configured in cluster_error_retry_attempts + and then raise the exception + """ + with patch.object(RedisCluster, '_execute_command') as execute_command: + def raise_conn_error(target_node, *args, **kwargs): + execute_command.failed_calls += 1 + raise ConnectionError() + + execute_command.side_effect = raise_conn_error + + rc = get_mocked_redis_client(host=default_host, port=default_port) + + with pytest.raises(ConnectionError): + rc.get("bar") + assert execute_command.failed_calls == \ + rc.cluster_error_retry_attempts + + def test_user_on_connect_function(self, request): + """ + Test support in passing on_connect function by the user + """ + + def on_connect(connection): + assert connection is not None + + mock = Mock(side_effect=on_connect) + + _get_client(RedisCluster, request, redis_connect_func=mock) + assert mock.called is True + + +@skip_if_not_cluster_mode() +class TestClusterRedisCommands: + def test_case_insensitive_command_names(self, r): + assert r.cluster_response_callbacks['cluster addslots'] == \ + r.cluster_response_callbacks['CLUSTER ADDSLOTS'] + + def test_get_and_set(self, r): + # get and set can't be tested independently of each other + assert r.get('a') is None + byte_string = b'value' + integer = 5 + unicode_string = chr(3456) + 'abcd' + chr(3421) + assert r.set('byte_string', byte_string) + assert r.set('integer', 5) + assert r.set('unicode_string', unicode_string) + assert r.get('byte_string') == byte_string + assert r.get('integer') == str(integer).encode() + assert r.get('unicode_string').decode('utf-8') == unicode_string + + def test_mget_nonatomic(self, r): + assert r.mget_nonatomic([]) == [] + assert r.mget_nonatomic(['a', 'b']) == [None, None] + r['a'] = '1' + r['b'] = '2' + r['c'] = '3' + + assert (r.mget_nonatomic('a', 'other', 'b', 'c') == + [b'1', None, b'2', b'3']) + + def test_mset_nonatomic(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + assert r.mset_nonatomic(d) + for k, v in d.items(): + assert r[k] == v + + def test_dbsize(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + assert r.mset_nonatomic(d) + assert r.dbsize() == len(d) + + def test_config_set(self, r): + assert r.config_set('slowlog-log-slower-than', 0) + + def test_client_setname(self, r): + r.client_setname('redis_py_test') + res = r.client_getname() + for client_name in res.values(): + assert client_name == 'redis_py_test' + + def test_exists(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.exists(*d.keys()) == len(d) + + def test_delete(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.delete(*d.keys()) == len(d) + assert r.delete(*d.keys()) == 0 + + def test_touch(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.touch(*d.keys()) == len(d) + + def test_unlink(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + r.mset_nonatomic(d) + assert r.unlink(*d.keys()) == len(d) + # Unlink is non-blocking 
so we sleep before + # verifying the deletion + sleep(0.1) + assert r.unlink(*d.keys()) == 0 + + def test_pubsub_channels_merge_results(self, r): + nodes = r.get_nodes() + channels = [] + i = 0 + for node in nodes: + channel = "foo{0}".format(i) + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + p.subscribe(channel) + b_channel = channel.encode('utf-8') + channels.append(b_channel) + # Assert that each node returns only the channel it subscribed to + sub_channels = node.redis_connection.pubsub_channels() + if not sub_channels: + # Try again after a short sleep + sleep(0.3) + sub_channels = node.redis_connection.pubsub_channels() + assert sub_channels == [b_channel] + i += 1 + # Assert that the cluster's pubsub_channels function returns ALL of + # the cluster's channels + result = r.pubsub_channels() + result.sort() + assert result == channels + + def test_pubsub_numsub_merge_results(self, r): + nodes = r.get_nodes() + channel = "foo" + b_channel = channel.encode('utf-8') + for node in nodes: + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + p.subscribe(channel) + # Assert that each node returns that only one client is subscribed + sub_chann_num = node.redis_connection.pubsub_numsub(channel) + if sub_chann_num == [(b_channel, 0)]: + sleep(0.3) + sub_chann_num = node.redis_connection.pubsub_numsub(channel) + assert sub_chann_num == [(b_channel, 1)] + # Assert that the cluster's pubsub_numsub function returns ALL clients + # subscribed to this channel in the entire cluster + assert r.pubsub_numsub(channel) == [(b_channel, len(nodes))] + + def test_pubsub_numpat_merge_results(self, r): + nodes = r.get_nodes() + pattern = "foo*" + for node in nodes: + # We will create different pubsub clients where each one is + # connected to a different node + p = r.pubsub(node) + p.psubscribe(pattern) + # Assert that each node returns that only one client is subscribed + sub_num_pat = node.redis_connection.pubsub_numpat() + if sub_num_pat == 0: + sleep(0.3) + sub_num_pat = node.redis_connection.pubsub_numpat() + assert sub_num_pat == 1 + # Assert that the cluster's pubsub_numsub function returns ALL clients + # subscribed to this channel in the entire cluster + assert r.pubsub_numpat() == len(nodes) + + def test_cluster_slots(self, r): + mock_all_nodes_resp(r, default_cluster_slots) + cluster_slots = r.cluster_slots() + assert isinstance(cluster_slots, dict) + assert len(default_cluster_slots) == len(cluster_slots) + assert cluster_slots.get((0, 8191)) is not None + assert cluster_slots.get((0, 8191)).get('primary') == \ + ('127.0.0.1', 7000) + + def test_cluster_addslots(self, r): + node = r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_addslots(node, 1, 2, 3) is True + + def test_cluster_countkeysinslot(self, r): + node = r.nodes_manager.get_node_from_slot(1) + mock_node_resp(node, 2) + assert r.cluster_countkeysinslot(1) == 2 + + def test_cluster_count_failure_report(self, r): + mock_all_nodes_resp(r, 0) + assert r.cluster_count_failure_report('node_0') == 0 + + def test_cluster_delslots(self): + cluster_slots = [ + [ + 0, 8191, + ['127.0.0.1', 7000, 'node_0'], + ], + [ + 8192, 16383, + ['127.0.0.1', 7001, 'node_1'], + ] + ] + r = get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots) + mock_all_nodes_resp(r, 'OK') + node0 = r.get_node(default_host, 7000) + node1 = r.get_node(default_host, 7001) + assert 
r.cluster_delslots(0, 8192) == [True, True] + assert node0.redis_connection.connection.read_response.called + assert node1.redis_connection.connection.read_response.called + + def test_cluster_failover(self, r): + node = r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_failover(node) is True + assert r.cluster_failover(node, 'FORCE') is True + assert r.cluster_failover(node, 'TAKEOVER') is True + with pytest.raises(RedisError): + r.cluster_failover(node, 'FORCT') + + def test_cluster_info(self, r): + info = r.cluster_info() + assert isinstance(info, dict) + assert info['cluster_state'] == 'ok' + + def test_cluster_keyslot(self, r): + mock_all_nodes_resp(r, 12182) + assert r.cluster_keyslot('foo') == 12182 + + def test_cluster_meet(self, r): + node = r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_meet(node, '127.0.0.1', 6379) is True + + def test_cluster_nodes(self, r): + response = ( + 'c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 ' + 'slave aa90da731f673a99617dfe930306549a09f83a6b 0 ' + '1447836263059 5 connected\n' + '9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 ' + 'master - 0 1447836264065 0 connected\n' + 'aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 ' + 'myself,master - 0 0 2 connected 5461-10922\n' + '1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 ' + 'slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 ' + '1447836262556 3 connected\n' + '4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 ' + 'master - 0 1447836262555 7 connected 0-5460\n' + '19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 ' + 'master - 0 1447836263562 3 connected 10923-16383\n' + 'fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 ' + 'master,fail - 1447829446956 1447829444948 1 disconnected\n' + ) + mock_all_nodes_resp(r, response) + nodes = r.cluster_nodes() + assert len(nodes) == 7 + assert nodes.get('172.17.0.7:7006') is not None + assert nodes.get('172.17.0.7:7006').get('node_id') == \ + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + + def test_cluster_replicate(self, r): + node = r.get_random_node() + all_replicas = r.get_replicas() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_replicate(node, 'c8253bae761cb61857d') is True + results = r.cluster_replicate(all_replicas, 'c8253bae761cb61857d') + for res in results.values(): + assert res is True + + def test_cluster_reset(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_reset(node) is True + assert r.cluster_reset(node, False) is True + all_results = r.cluster_reset(all_nodes, False) + for res in all_results.values(): + assert res is True + + def test_cluster_save_config(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_save_config(node) is True + all_results = r.cluster_save_config(all_nodes) + for res in all_results.values(): + assert res is True + + def test_cluster_get_keys_in_slot(self, r): + response = [b'{foo}1', b'{foo}2'] + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, response) + keys = r.cluster_get_keys_in_slot(12182, 4) + assert keys == response + + def test_cluster_set_config_epoch(self, r): + node = r.get_random_node() + all_nodes = r.get_nodes() + mock_all_nodes_resp(r, 'OK') + assert r.cluster_set_config_epoch(node, 3) is True + all_results = r.cluster_set_config_epoch(all_nodes, 3) + for res in all_results.values(): + assert res is True + + def test_cluster_setslot(self, r): + node 
= r.get_random_node() + mock_node_resp(node, 'OK') + assert r.cluster_setslot(node, 'node_0', 1218, 'IMPORTING') is True + assert r.cluster_setslot(node, 'node_0', 1218, 'NODE') is True + assert r.cluster_setslot(node, 'node_0', 1218, 'MIGRATING') is True + with pytest.raises(RedisError): + r.cluster_failover(node, 'STABLE') + with pytest.raises(RedisError): + r.cluster_failover(node, 'STATE') + + def test_cluster_setslot_stable(self, r): + node = r.nodes_manager.get_node_from_slot(12182) + mock_node_resp(node, 'OK') + assert r.cluster_setslot_stable(12182) is True + assert node.redis_connection.connection.read_response.called + + def test_cluster_replicas(self, r): + response = [b'01eca22229cf3c652b6fca0d09ff6941e0d2e3 ' + b'127.0.0.1:6377@16377 slave ' + b'52611e796814b78e90ad94be9d769a4f668f9a 0 ' + b'1634550063436 4 connected', + b'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d ' + b'127.0.0.1:6378@16378 slave ' + b'52611e796814b78e90ad94be9d769a4f668f9a 0 ' + b'1634550063436 4 connected'] + mock_all_nodes_resp(r, response) + replicas = r.cluster_replicas('52611e796814b78e90ad94be9d769a4f668f9a') + assert replicas.get('127.0.0.1:6377') is not None + assert replicas.get('127.0.0.1:6378') is not None + assert replicas.get('127.0.0.1:6378').get('node_id') == \ + 'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d' + + def test_readonly(self): + r = get_mocked_redis_client(host=default_host, port=default_port) + node = r.get_random_node() + all_replicas = r.get_replicas() + mock_all_nodes_resp(r, 'OK') + assert r.readonly(node) is True + all_replicas_results = r.readonly() + for res in all_replicas_results.values(): + assert res is True + for replica in all_replicas: + assert replica.redis_connection.connection.read_response.called + + def test_readwrite(self): + r = get_mocked_redis_client(host=default_host, port=default_port) + node = r.get_random_node() + mock_all_nodes_resp(r, 'OK') + all_replicas = r.get_replicas() + assert r.readwrite(node) is True + all_replicas_results = r.readwrite() + for res in all_replicas_results.values(): + assert res is True + for replica in all_replicas: + assert replica.redis_connection.connection.read_response.called + + def test_bgsave(self, r): + assert r.bgsave() + sleep(0.3) + assert r.bgsave(True) + + def test_info(self, r): + # Map keys to same slot + r.set('x{1}', 1) + r.set('y{1}', 2) + r.set('z{1}', 3) + # Get node that handles the slot + slot = r.keyslot('x{1}') + node = r.nodes_manager.get_node_from_slot(slot) + # Run info on that node + info = r.info(target_nodes=node) + assert isinstance(info, dict) + assert info['db0']['keys'] == 3 + + def test_slowlog_get(self, r, slowlog): + assert r.slowlog_reset() + unicode_string = chr(3456) + 'abcd' + chr(3421) + r.get(unicode_string) + + slot = r.keyslot(unicode_string) + node = r.nodes_manager.get_node_from_slot(slot) + slowlog = r.slowlog_get(target_nodes=node) + assert isinstance(slowlog, list) + commands = [log['command'] for log in slowlog] + + get_command = b' '.join((b'GET', unicode_string.encode('utf-8'))) + assert get_command in commands + assert b'SLOWLOG RESET' in commands + + # the order should be ['GET ', 'SLOWLOG RESET'], + # but if other clients are executing commands at the same time, there + # could be commands, before, between, or after, so just check that + # the two we care about are in the appropriate order. 
+ assert commands.index(get_command) < commands.index(b'SLOWLOG RESET') + + # make sure other attributes are typed correctly + assert isinstance(slowlog[0]['start_time'], int) + assert isinstance(slowlog[0]['duration'], int) + + def test_slowlog_get_limit(self, r, slowlog): + assert r.slowlog_reset() + r.get('foo') + node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + slowlog = r.slowlog_get(1, target_nodes=node) + assert isinstance(slowlog, list) + # only one command, based on the number we passed to slowlog_get() + assert len(slowlog) == 1 + + def test_slowlog_length(self, r, slowlog): + r.get('foo') + node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + slowlog_len = r.slowlog_len(target_nodes=node) + assert isinstance(slowlog_len, int) + + def test_time(self, r): + t = r.time(target_nodes=r.get_primaries()[0]) + assert len(t) == 2 + assert isinstance(t[0], int) + assert isinstance(t[1], int) + + @skip_if_server_version_lt('4.0.0') + def test_memory_usage(self, r): + r.set('foo', 'bar') + assert isinstance(r.memory_usage('foo'), int) + + @skip_if_server_version_lt('4.0.0') + def test_memory_malloc_stats(self, r): + assert r.memory_malloc_stats() + + @skip_if_server_version_lt('4.0.0') + def test_memory_stats(self, r): + # put a key into the current db to make sure that "db." + # has data + r.set('foo', 'bar') + node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) + stats = r.memory_stats(target_nodes=node) + assert isinstance(stats, dict) + for key, value in stats.items(): + if key.startswith('db.'): + assert isinstance(value, dict) + + @skip_if_server_version_lt('4.0.0') + def test_memory_help(self, r): + with pytest.raises(NotImplementedError): + r.memory_help() + + @skip_if_server_version_lt('4.0.0') + def test_memory_doctor(self, r): + with pytest.raises(NotImplementedError): + r.memory_doctor() + + def test_object(self, r): + r['a'] = 'foo' + assert isinstance(r.object('refcount', 'a'), int) + assert isinstance(r.object('idletime', 'a'), int) + assert r.object('encoding', 'a') in (b'raw', b'embstr') + assert r.object('idletime', 'invalid-key') is None + + def test_lastsave(self, r): + node = r.get_primaries()[0] + assert isinstance(r.lastsave(target_nodes=node), + datetime.datetime) + + def test_echo(self, r): + node = r.get_primaries()[0] + assert r.echo('foo bar', node) == b'foo bar' + + @skip_if_server_version_lt('1.0.0') + def test_debug_segfault(self, r): + with pytest.raises(NotImplementedError): + r.debug_segfault() + + def test_config_resetstat(self, r): + node = r.get_primaries()[0] + r.ping(target_nodes=node) + prior_commands_processed = \ + int(r.info(target_nodes=node)['total_commands_processed']) + assert prior_commands_processed >= 1 + r.config_resetstat(target_nodes=node) + reset_commands_processed = \ + int(r.info(target_nodes=node)['total_commands_processed']) + assert reset_commands_processed < prior_commands_processed + + @skip_if_server_version_lt('6.2.0') + def test_client_trackinginfo(self, r): + node = r.get_primaries()[0] + res = r.client_trackinginfo(target_nodes=node) + assert len(res) > 2 + assert 'prefixes' in res + + @skip_if_server_version_lt('2.9.50') + def test_client_pause(self, r): + node = r.get_primaries()[0] + assert r.client_pause(1, target_nodes=node) + assert r.client_pause(timeout=1, target_nodes=node) + with pytest.raises(RedisError): + r.client_pause(timeout='not an integer', target_nodes=node) + + @skip_if_server_version_lt('6.2.0') + def test_client_unpause(self, r): + assert r.client_unpause() + + 
@skip_if_server_version_lt('5.0.0') + def test_client_id(self, r): + node = r.get_primaries()[0] + assert r.client_id(target_nodes=node) > 0 + + @skip_if_server_version_lt('5.0.0') + def test_client_unblock(self, r): + node = r.get_primaries()[0] + myid = r.client_id(target_nodes=node) + assert not r.client_unblock(myid, target_nodes=node) + assert not r.client_unblock(myid, error=True, target_nodes=node) + assert not r.client_unblock(myid, error=False, target_nodes=node) + + @skip_if_server_version_lt('6.0.0') + def test_client_getredir(self, r): + node = r.get_primaries()[0] + assert isinstance(r.client_getredir(target_nodes=node), int) + assert r.client_getredir(target_nodes=node) == -1 + + @skip_if_server_version_lt('6.2.0') + def test_client_info(self, r): + node = r.get_primaries()[0] + info = r.client_info(target_nodes=node) + assert isinstance(info, dict) + assert 'addr' in info + + @skip_if_server_version_lt('2.6.9') + def test_client_kill(self, r, r2): + node = r.get_primaries()[0] + r.client_setname('redis-py-c1') + r2.client_setname('redis-py-c2') + clients = [client for client in r.client_list()[node.name] + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 2 + clients_by_name = dict([(client.get('name'), client) + for client in clients]) + + client_addr = clients_by_name['redis-py-c2'].get('addr') + assert r.client_kill(client_addr, target_nodes=node) is True + + clients = [client for client in r.client_list()[node.name] + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 1 + assert clients[0].get('name') == 'redis-py-c1' + + +@skip_if_not_cluster_mode() +class TestNodesManager: + def test_load_balancer(self, r): + n_manager = r.nodes_manager + lb = n_manager.read_load_balancer + slot_1 = 1257 + slot_2 = 8975 + node_1 = ClusterNode(default_host, 6379, PRIMARY) + node_2 = ClusterNode(default_host, 6378, REPLICA) + node_3 = ClusterNode(default_host, 6377, REPLICA) + node_4 = ClusterNode(default_host, 6376, PRIMARY) + node_5 = ClusterNode(default_host, 6375, REPLICA) + n_manager.slots_cache = { + slot_1: [node_1, node_2, node_3], + slot_2: [node_4, node_5] + } + primary1_name = n_manager.slots_cache[slot_1][0].name + primary2_name = n_manager.slots_cache[slot_2][0].name + list1_size = len(n_manager.slots_cache[slot_1]) + list2_size = len(n_manager.slots_cache[slot_2]) + # slot 1 + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary1_name, list1_size) == 1 + assert lb.get_server_index(primary1_name, list1_size) == 2 + assert lb.get_server_index(primary1_name, list1_size) == 0 + # slot 2 + assert lb.get_server_index(primary2_name, list2_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 1 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + lb.reset() + assert lb.get_server_index(primary1_name, list1_size) == 0 + assert lb.get_server_index(primary2_name, list2_size) == 0 + + def test_init_slots_cache_not_all_slots_covered(self): + """ + Test that if not all slots are covered it should raise an exception + """ + # Missing slot 5460 + cluster_slots = [ + [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]], + [5461, 10922, ['127.0.0.1', 7001], + ['127.0.0.1', 7004]], + [10923, 16383, ['127.0.0.1', 7002], + ['127.0.0.1', 7005]], + ] + with pytest.raises(RedisClusterException) as ex: + get_mocked_redis_client(host=default_host, port=default_port, + cluster_slots=cluster_slots) + assert str(ex.value).startswith( + "All slots are not covered after 
query all startup_nodes.")
+
+    def test_init_slots_cache_not_require_full_coverage_error(self):
+        """
+        When require_full_coverage is set to False and not all slots are
+        covered, if one of the nodes has the
+        'cluster-require-full-coverage' config set to 'yes', cluster
+        initialization should fail
+        """
+        # Missing slot 5460
+        cluster_slots = [
+            [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]],
+            [5461, 10922, ['127.0.0.1', 7001],
+             ['127.0.0.1', 7004]],
+            [10923, 16383, ['127.0.0.1', 7002],
+             ['127.0.0.1', 7005]],
+        ]
+
+        with pytest.raises(RedisClusterException):
+            get_mocked_redis_client(host=default_host, port=default_port,
+                                    cluster_slots=cluster_slots,
+                                    require_full_coverage=False,
+                                    coverage_result='yes')
+
+    def test_init_slots_cache_not_require_full_coverage_success(self):
+        """
+        When require_full_coverage is set to False and not all slots are
+        covered, if all of the nodes have the
+        'cluster-require-full-coverage' config set to 'no', cluster
+        initialization should succeed
+        """
+        # Missing slot 5460
+        cluster_slots = [
+            [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]],
+            [5461, 10922, ['127.0.0.1', 7001],
+             ['127.0.0.1', 7004]],
+            [10923, 16383, ['127.0.0.1', 7002],
+             ['127.0.0.1', 7005]],
+        ]
+
+        rc = get_mocked_redis_client(host=default_host, port=default_port,
+                                     cluster_slots=cluster_slots,
+                                     require_full_coverage=False,
+                                     coverage_result='no')
+
+        assert 5460 not in rc.nodes_manager.slots_cache
+
+    def test_init_slots_cache_not_require_full_coverage_skips_check(self):
+        """
+        Test that when require_full_coverage is set to False and
+        skip_full_coverage_check is set to True, cluster initialization
+        succeeds without checking the nodes' Redis configurations
+        """
+        # Missing slot 5460
+        cluster_slots = [
+            [0, 5459, ['127.0.0.1', 7000], ['127.0.0.1', 7003]],
+            [5461, 10922, ['127.0.0.1', 7001],
+             ['127.0.0.1', 7004]],
+            [10923, 16383, ['127.0.0.1', 7002],
+             ['127.0.0.1', 7005]],
+        ]
+
+        with patch.object(NodesManager,
+                          'cluster_require_full_coverage') as conf_check_mock:
+            rc = get_mocked_redis_client(host=default_host,
+                                         port=default_port,
+                                         cluster_slots=cluster_slots,
+                                         require_full_coverage=False,
+                                         skip_full_coverage_check=True,
+                                         coverage_result='no')
+
+            assert conf_check_mock.called is False
+            assert 5460 not in rc.nodes_manager.slots_cache
+
+    def test_init_slots_cache(self):
+        """
+        Test that the slots cache can be initialized and that all slots are
+        covered
+        """
+        good_slots_resp = [
+            [0, 5460, ['127.0.0.1', 7000], ['127.0.0.2', 7003]],
+            [5461, 10922, ['127.0.0.1', 7001], ['127.0.0.2', 7004]],
+            [10923, 16383, ['127.0.0.1', 7002], ['127.0.0.2', 7005]],
+        ]
+
+        rc = get_mocked_redis_client(host=default_host, port=default_port,
+                                     cluster_slots=good_slots_resp)
+        n_manager = rc.nodes_manager
+        assert len(n_manager.slots_cache) == REDIS_CLUSTER_HASH_SLOTS
+        for slot_info in good_slots_resp:
+            all_hosts = ['127.0.0.1', '127.0.0.2']
+            all_ports = [7000, 7001, 7002, 7003, 7004, 7005]
+            slot_start = slot_info[0]
+            slot_end = slot_info[1]
+            for i in range(slot_start, slot_end + 1):
+                assert len(n_manager.slots_cache[i]) == len(slot_info[2:])
+                assert n_manager.slots_cache[i][0].host in all_hosts
+                assert n_manager.slots_cache[i][1].host in all_hosts
+                assert n_manager.slots_cache[i][0].port in all_ports
+                assert n_manager.slots_cache[i][1].port in all_ports
+
+        assert len(n_manager.nodes_cache) == 6
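As the three coverage tests above exercise, `require_full_coverage` and `skip_full_coverage_check` are also plain constructor kwargs on the real client; roughly (a sketch against a local cluster):

``` pycon
    >>> from redis.cluster import RedisCluster as Redis
    >>> # tolerate unassigned slots, and skip querying each node's
    >>> # cluster-require-full-coverage setting
    >>> rc = Redis(host='localhost', port=6379,
    ...            require_full_coverage=False,
    ...            skip_full_coverage_check=True)
```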
+    def test_empty_startup_nodes(self):
+        """
+        It should not be possible to create a node manager with no nodes
+        specified
+        """
+        with pytest.raises(RedisClusterException):
+            NodesManager([])
+
+    def test_wrong_startup_nodes_type(self):
+        """
+        If something other than a list-type iterable is provided it should
+        fail
+        """
+        with pytest.raises(RedisClusterException):
+            NodesManager({})
+
+    def test_init_slots_cache_slots_collision(self, request):
+        """
+        Test that if 2 nodes do not agree on the same slots setup it should
+        raise an error. In this test both nodes will say that the first
+        slots block should be bound to different servers.
+        """
+        with patch.object(NodesManager,
+                          'create_redis_node') as create_redis_node:
+            def create_mocked_redis_node(host, port, **kwargs):
+                """
+                Helper function to return custom slots cache data from
+                different redis nodes
+                """
+                if port == 7000:
+                    result = [
+                        [
+                            0,
+                            5460,
+                            ['127.0.0.1', 7000],
+                            ['127.0.0.1', 7003],
+                        ],
+                        [
+                            5461,
+                            10922,
+                            ['127.0.0.1', 7001],
+                            ['127.0.0.1', 7004],
+                        ],
+                    ]
+
+                elif port == 7001:
+                    result = [
+                        [
+                            0,
+                            5460,
+                            ['127.0.0.1', 7001],
+                            ['127.0.0.1', 7003],
+                        ],
+                        [
+                            5461,
+                            10922,
+                            ['127.0.0.1', 7000],
+                            ['127.0.0.1', 7004],
+                        ],
+                    ]
+                else:
+                    result = []
+
+                r_node = Redis(
+                    host=host,
+                    port=port
+                )
+
+                orig_execute_command = r_node.execute_command
+
+                def execute_command(*args, **kwargs):
+                    if args[0] == 'CLUSTER SLOTS':
+                        return result
+                    elif args[1] == 'cluster-require-full-coverage':
+                        return {'cluster-require-full-coverage': 'yes'}
+                    else:
+                        return orig_execute_command(*args, **kwargs)
+
+                r_node.execute_command = execute_command
+                return r_node
+
+            create_redis_node.side_effect = create_mocked_redis_node
+
+            with pytest.raises(RedisClusterException) as ex:
+                node_1 = ClusterNode('127.0.0.1', 7000)
+                node_2 = ClusterNode('127.0.0.1', 7001)
+                RedisCluster(startup_nodes=[node_1, node_2])
+            assert str(ex.value).startswith(
+                "startup_nodes could not agree on a valid slots cache"), str(
+                ex.value)
+
+    def test_cluster_one_instance(self):
+        """
+        If the cluster consists of only one node, there are some hacks that
+        must be validated to work.
+        """
+        node = ClusterNode(default_host, default_port)
+        cluster_slots = [[0, 16383, ['', default_port]]]
+        rc = get_mocked_redis_client(startup_nodes=[node],
+                                     cluster_slots=cluster_slots)
+
+        n = rc.nodes_manager
+        assert len(n.nodes_cache) == 1
+        n_node = rc.get_node(node_name=node.name)
+        assert n_node is not None
+        assert n_node == node
+        assert n_node.server_type == PRIMARY
+        assert len(n.slots_cache) == REDIS_CLUSTER_HASH_SLOTS
+        for i in range(0, REDIS_CLUSTER_HASH_SLOTS):
+            assert n.slots_cache[i] == [n_node]
+
+    def test_init_with_down_node(self):
+        """
+        If I can't connect to one of the nodes, everything should still work.
+        But if I can't connect to any of the nodes, an exception should be
+        thrown.
+        """
+        with patch.object(NodesManager,
+                          'create_redis_node') as create_redis_node:
+            def create_mocked_redis_node(host, port, **kwargs):
+                if port == 7000:
+                    raise ConnectionError('mock connection error for 7000')
+
+                r_node = Redis(host=host, port=port, decode_responses=True)
+
+                def execute_command(*args, **kwargs):
+                    if args[0] == 'CLUSTER SLOTS':
+                        return [
+                            [
+                                0, 8191,
+                                ['127.0.0.1', 7001, 'node_1'],
+                            ],
+                            [
+                                8192, 16383,
+                                ['127.0.0.1', 7002, 'node_2'],
+                            ]
+                        ]
+                    elif args[1] == 'cluster-require-full-coverage':
+                        return {'cluster-require-full-coverage': 'yes'}
+
+                r_node.execute_command = execute_command
+
+                return r_node
+
+            create_redis_node.side_effect = create_mocked_redis_node
+
+            node_1 = ClusterNode('127.0.0.1', 7000)
+            node_2 = ClusterNode('127.0.0.1', 7001)
+
+            # If all startup nodes fail to connect, a connection error should
+            # be thrown
+            with pytest.raises(RedisClusterException) as e:
+                RedisCluster(startup_nodes=[node_1])
+            assert 'Redis Cluster cannot be connected' in str(e.value)
+
+            with patch.object(CommandsParser, 'initialize',
+                              autospec=True) as cmd_parser_initialize:
+
+                def cmd_init_mock(self, r):
+                    self.commands = {'get': {'name': 'get', 'arity': 2,
+                                             'flags': ['readonly',
+                                                       'fast'],
+                                             'first_key_pos': 1,
+                                             'last_key_pos': 1,
+                                             'step_count': 1}}
+
+                cmd_parser_initialize.side_effect = cmd_init_mock
+                # When at least one startup node is reachable, the cluster
+                # initialization should succeed
+                rc = RedisCluster(startup_nodes=[node_1, node_2])
+                assert rc.get_node(host=default_host, port=7001) is not None
+                assert rc.get_node(host=default_host, port=7002) is not None
diff --git a/tests/test_commands.py b/tests/test_commands.py
index 6d65931539..998bc0f0e6 100644
--- a/tests/test_commands.py
+++ b/tests/test_commands.py
@@ -8,9 +8,10 @@
 from redis.client import parse_info
 from redis import exceptions
-
+from redis.commands import CommandsParser
 from .conftest import (
     _get_client,
+    skip_if_cluster_mode,
     skip_if_server_version_gte,
     skip_if_server_version_lt,
     skip_unless_arch_bits,
@@ -46,6 +47,7 @@ def get_stream_message(client, stream, message_id):

 # RESPONSE CALLBACKS
+@skip_if_cluster_mode()
 class TestResponseCallbacks:
     "Tests for the response callback system"
@@ -60,6 +62,7 @@ def test_case_insensitive_command_names(self, r):
         assert r.response_callbacks['del'] == r.response_callbacks['DEL']

+@skip_if_cluster_mode()
 class TestRedisCommands:
     def test_command_on_invalid_key_type(self, r):
         r.lpush('a', '1')
@@ -122,6 +125,7 @@ def test_acl_getuser_setuser(self, r, request):
         def teardown():
             r.acl_deluser(username)
+
         request.addfinalizer(teardown)

         # test enabled=False
@@ -215,6 +219,7 @@ def test_acl_list(self, r, request):
         def teardown():
             r.acl_deluser(username)
+
         request.addfinalizer(teardown)

         assert r.acl_setuser(username, enabled=False, reset=True)
@@ -262,6 +267,7 @@ def test_acl_setuser_categories_without_prefix_fails(self, r, request):
         def teardown():
             r.acl_deluser(username)
+
         request.addfinalizer(teardown)

         with pytest.raises(exceptions.DataError):
@@ -273,6 +279,7 @@ def test_acl_setuser_commands_without_prefix_fails(self, r, request):
         def teardown():
             r.acl_deluser(username)
+
         request.addfinalizer(teardown)

         with pytest.raises(exceptions.DataError):
@@ -284,6 +291,7 @@ def test_acl_setuser_add_passwords_and_nopass_fails(self, r, request):
         def teardown():
             r.acl_deluser(username)
+
         request.addfinalizer(teardown)

         with pytest.raises(exceptions.DataError):
@@ -593,6 +601,7 @@ def parse_response(connection, command_name, **options):
             # Complexity info stored
as fourth item in list response.insert(3, COMPLEXITY_STATEMENT) return r.response_callbacks[command_name](responses, **options) + r.parse_response = parse_response # test @@ -1195,22 +1204,22 @@ def test_stralgo_lcs(self, r): # test other labels assert r.stralgo('LCS', value1, value2, len=True) == len(res) assert r.stralgo('LCS', value1, value2, idx=True) == \ - { - 'len': len(res), - 'matches': [[(4, 7), (5, 8)], [(2, 3), (0, 1)]] - } + { + 'len': len(res), + 'matches': [[(4, 7), (5, 8)], [(2, 3), (0, 1)]] + } assert r.stralgo('LCS', value1, value2, idx=True, withmatchlen=True) == \ - { - 'len': len(res), - 'matches': [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]] - } + { + 'len': len(res), + 'matches': [[4, (4, 7), (5, 8)], [2, (2, 3), (0, 1)]] + } assert r.stralgo('LCS', value1, value2, idx=True, minmatchlen=4, withmatchlen=True) == \ - { - 'len': len(res), - 'matches': [[4, (4, 7), (5, 8)]] - } + { + 'len': len(res), + 'matches': [[4, (4, 7), (5, 8)]] + } @skip_if_server_version_lt('6.0.0') def test_stralgo_negative(self, r): @@ -1758,16 +1767,16 @@ def test_zinter(self, r): r.zinter(['a', 'b', 'c'], aggregate='foo', withscores=True) # aggregate with SUM assert r.zinter(['a', 'b', 'c'], withscores=True) \ - == [(b'a3', 8), (b'a1', 9)] + == [(b'a3', 8), (b'a1', 9)] # aggregate with MAX assert r.zinter(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ - == [(b'a3', 5), (b'a1', 6)] + == [(b'a3', 5), (b'a1', 6)] # aggregate with MIN assert r.zinter(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ - == [(b'a1', 1), (b'a3', 1)] + == [(b'a1', 1), (b'a3', 1)] # with weights assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ - == [(b'a3', 20), (b'a1', 23)] + == [(b'a3', 20), (b'a1', 23)] def test_zinterstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) @@ -2059,14 +2068,14 @@ def test_zunion(self, r): assert r.zunion(['a', 'b', 'c'], withscores=True) == \ [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] # max - assert r.zunion(['a', 'b', 'c'], aggregate='MAX', withscores=True)\ - == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + assert r.zunion(['a', 'b', 'c'], aggregate='MAX', withscores=True) \ + == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] # min - assert r.zunion(['a', 'b', 'c'], aggregate='MIN', withscores=True)\ - == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] + assert r.zunion(['a', 'b', 'c'], aggregate='MIN', withscores=True) \ + == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] # with weight - assert r.zunion({'a': 1, 'b': 2, 'c': 3}, withscores=True)\ - == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + assert r.zunion({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ + == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] def test_zunionstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) @@ -2927,10 +2936,10 @@ def test_xautoclaim(self, r): # which only returns message ids assert r.xautoclaim(stream, group, consumer1, min_idle_time=0, start_id=0, justid=True) == \ - [message_id1, message_id2] + [message_id1, message_id2] assert r.xautoclaim(stream, group, consumer1, min_idle_time=0, start_id=message_id2, justid=True) == \ - [message_id2] + [message_id2] @skip_if_server_version_lt('6.2.0') def test_xautoclaim_negative(self, r): @@ -3511,51 +3520,51 @@ def test_bitfield_operations(self, r): # comments show affected bits bf = r.bitfield('a') resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .get('u8', 0) # 00000000 - .get('u4', 8) # 1111 - .get('u4', 12) # 1111 - .get('u4', 13) # 111 0 + .set('u8', 8, 255) # 00000000 11111111 + 
.get('u8', 0) # 00000000 + .get('u4', 8) # 1111 + .get('u4', 12) # 1111 + .get('u4', 13) # 111 0 .execute()) assert resp == [0, 0, 15, 15, 14] # .set() returns the previous value... resp = (bf - .set('u8', 4, 1) # 0000 0001 - .get('u16', 0) # 00000000 00011111 - .set('u16', 0, 0) # 00000000 00000000 + .set('u8', 4, 1) # 0000 0001 + .get('u16', 0) # 00000000 00011111 + .set('u16', 0, 0) # 00000000 00000000 .execute()) assert resp == [15, 31, 31] # incrby adds to the value resp = (bf .incrby('u8', 8, 254) # 00000000 11111110 - .incrby('u8', 8, 1) # 00000000 11111111 - .get('u16', 0) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [254, 255, 255] # Verify overflow protection works as a method: r.delete('a') resp = (bf - .set('u8', 8, 254) # 00000000 11111110 + .set('u8', 8, 254) # 00000000 11111110 .overflow('fail') - .incrby('u8', 8, 2) # incrby 2 would overflow, None returned - .incrby('u8', 8, 1) # 00000000 11111111 - .incrby('u8', 8, 1) # incrby 1 would overflow, None returned - .get('u16', 0) # 00000000 11111111 + .incrby('u8', 8, 2) # incrby 2 would overflow, None returned + .incrby('u8', 8, 1) # 00000000 11111111 + .incrby('u8', 8, 1) # incrby 1 would overflow, None returned + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [0, None, 255, None, 255] # Verify overflow protection works as arg to incrby: r.delete('a') resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1) # 00000000 00000000 wrap default - .set('u8', 8, 255) # 00000000 11111111 + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 00000000 wrap default + .set('u8', 8, 255) # 00000000 11111111 .incrby('u8', 8, 1, 'FAIL') # 00000000 11111111 fail - .incrby('u8', 8, 1) # 00000000 11111111 still fail - .get('u16', 0) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 still fail + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [0, 0, 0, None, None, 255] @@ -3563,9 +3572,9 @@ def test_bitfield_operations(self, r): r.delete('a') bf = r.bitfield('a', default_overflow='FAIL') resp = (bf - .set('u8', 8, 255) # 00000000 11111111 - .incrby('u8', 8, 1) # 00000000 11111111 fail default - .get('u16', 0) # 00000000 11111111 + .set('u8', 8, 255) # 00000000 11111111 + .incrby('u8', 8, 1) # 00000000 11111111 fail default + .get('u16', 0) # 00000000 11111111 .execute()) assert resp == [0, None, 255] @@ -3672,6 +3681,7 @@ def test_replicaof(self, r): assert r.replicaof("NO", "ONE") +@skip_if_cluster_mode() class TestBinarySave: def test_binary_get_set(self, r): @@ -3757,3 +3767,60 @@ def test_floating_point_encoding(self, r): timestamp = 1349673917.939762 r.zadd('a', {'a1': timestamp}) assert r.zscore('a', 'a1') == timestamp + + +class TestCommandsParser: + def test_init_commands(self, r): + commands_parser = CommandsParser(r) + assert commands_parser.commands is not None + assert 'get' in commands_parser.commands + + def test_get_keys_predetermined_key_location(self, r): + commands_parser = CommandsParser(r) + args1 = ['GET', 'foo'] + args2 = ['OBJECT', 'encoding', 'foo'] + args3 = ['MGET', 'foo', 'bar', 'foobar'] + assert commands_parser.get_keys(r, *args1) == ['foo'] + assert commands_parser.get_keys(r, *args2) == ['foo'] + assert commands_parser.get_keys(r, *args3) == ['foo', 'bar', 'foobar'] + + @pytest.mark.filterwarnings("ignore:ResponseError") + def test_get_moveable_keys(self, r): + commands_parser = CommandsParser(r) + args1 = ['EVAL', 'return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}', 
2, 'key1',
+                 'key2', 'first', 'second']
+        args2 = ['XREAD', 'COUNT', 2, b'STREAMS', 'mystream', 'writers', 0, 0]
+        args3 = ['ZUNIONSTORE', 'out', 2, 'zset1', 'zset2', 'WEIGHTS', 2, 3]
+        args4 = ['GEORADIUS', 'Sicily', 15, 37, 200, 'km', 'WITHCOORD',
+                 b'STORE', 'out']
+        args5 = ['MEMORY USAGE', 'foo']
+        args6 = ['MIGRATE', '192.168.1.34', 6379, "", 0, 5000, b'KEYS',
+                 'key1', 'key2', 'key3']
+        args7 = ['MIGRATE', '192.168.1.34', 6379, "key1", 0, 5000]
+        args8 = ['STRALGO', 'LCS', 'STRINGS', 'string_a', 'string_b']
+        args9 = ['STRALGO', 'LCS', 'KEYS', 'key1', 'key2']
+
+        # list.sort() sorts in place and returns None, so it cannot be used
+        # to compare the returned key lists; use sorted() instead
+        assert sorted(commands_parser.get_keys(
+            r, *args1)) == ['key1', 'key2']
+        assert sorted(commands_parser.get_keys(
+            r, *args2)) == ['mystream', 'writers']
+        assert sorted(commands_parser.get_keys(
+            r, *args3)) == ['out', 'zset1', 'zset2']
+        assert sorted(commands_parser.get_keys(
+            r, *args4)) == ['Sicily', 'out']
+        assert sorted(commands_parser.get_keys(r, *args5)) == ['foo']
+        assert sorted(commands_parser.get_keys(
+            r, *args6)) == ['key1', 'key2', 'key3']
+        assert sorted(commands_parser.get_keys(r, *args7)) == ['key1']
+        assert commands_parser.get_keys(r, *args8) is None
+        assert sorted(commands_parser.get_keys(
+            r, *args9)) == ['key1', 'key2']
+
+    def test_get_pubsub_keys(self, r):
+        commands_parser = CommandsParser(r)
+        args1 = ['PUBLISH', 'foo', 'bar']
+        args2 = ['PUBSUB NUMSUB', 'foo1', 'foo2', 'foo3']
+        args3 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3']
+        assert commands_parser.get_keys(r, *args1) == ['foo']
+        assert commands_parser.get_keys(r, *args2) == ['foo1', 'foo2', 'foo3']
+        assert commands_parser.get_keys(r, *args3) == ['foo1', 'foo2', 'foo3']
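
These assertions exercise the new `CommandsParser` (added in
`redis/commands/parser.py` by this patch), which the cluster client uses to
extract a command's keys so the command can be routed to the node holding
their slots. A minimal sketch of the same call pattern the tests use; the
import path is taken from this patch's file list, and the standalone usage
shown here is illustrative rather than documented public API:

``` pycon
    >>> from redis import Redis
    >>> from redis.commands.parser import CommandsParser
    >>> r = Redis()
    >>> commands_parser = CommandsParser(r)
    >>> # key positions are resolved from the server's COMMAND output
    >>> commands_parser.get_keys(r, 'MGET', 'foo', 'bar', 'foobar')
    ['foo', 'bar', 'foobar']
```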
diff --git a/tests/test_connection.py b/tests/test_connection.py
index fa9a2b0c90..2ca858d263 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -4,10 +4,11 @@
 from redis.exceptions import InvalidResponse, ModuleError
 from redis.utils import HIREDIS_AVAILABLE
-from .conftest import skip_if_server_version_lt
+from .conftest import skip_if_server_version_lt, skip_if_cluster_mode
 
 
 @pytest.mark.skipif(HIREDIS_AVAILABLE, reason='PythonParser only')
+@skip_if_cluster_mode()
 def test_invalid_response(r):
     raw = b'x'
     parser = r.connection._parser
@@ -17,12 +18,14 @@ def test_invalid_response(r):
     assert str(cm.value) == 'Protocol Error: %r' % raw
 
 
+@skip_if_cluster_mode()
 @skip_if_server_version_lt('4.0.0')
 def test_loaded_modules(r, modclient):
     assert r.loaded_modules == []
     assert 'rejson' in modclient.loaded_modules.keys()
 
 
+@skip_if_cluster_mode()
 @skip_if_server_version_lt('4.0.0')
 def test_loading_external_modules(r, modclient):
     def inner():
diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py
index 8d2ad041a0..4708057e98 100644
--- a/tests/test_connection_pool.py
+++ b/tests/test_connection_pool.py
@@ -7,7 +7,8 @@
 from threading import Thread
 from redis.connection import ssl_available, to_bool
-from .conftest import skip_if_server_version_lt, _get_client
+from .conftest import skip_if_server_version_lt, skip_if_cluster_mode,\
+    _get_client
 from .test_pubsub import wait_for_message
 
 
@@ -43,15 +44,15 @@ def test_connection_creation(self):
         assert isinstance(connection, DummyConnection)
         assert connection.kwargs == connection_kwargs
 
-    def test_multiple_connections(self, master_host):
-        connection_kwargs = {'host': master_host}
+    def test_multiple_connections(self, master_host, master_port):
+        connection_kwargs = {'host': master_host, 'port': master_port}
         pool = self.get_pool(connection_kwargs=connection_kwargs)
         c1 = pool.get_connection('_')
         c2 = pool.get_connection('_')
         assert c1 != c2
 
-    def test_max_connections(self, master_host):
-        connection_kwargs = {'host': master_host}
+    def test_max_connections(self, master_host, master_port):
+        connection_kwargs = {'host': master_host, 'port': master_port}
         pool = self.get_pool(max_connections=2,
                              connection_kwargs=connection_kwargs)
         pool.get_connection('_')
@@ -59,8 +60,9 @@ def test_max_connections(self, master_host):
         with pytest.raises(redis.ConnectionError):
             pool.get_connection('_')
 
-    def test_reuse_previously_released_connection(self, master_host):
-        connection_kwargs = {'host': master_host}
+    def test_reuse_previously_released_connection(self, master_host,
+                                                  master_port):
+        connection_kwargs = {'host': master_host, 'port': master_port}
         pool = self.get_pool(connection_kwargs=connection_kwargs)
         c1 = pool.get_connection('_')
         pool.release(c1)
@@ -463,6 +465,7 @@ def get_connection(self, *args, **kwargs):
     assert pool.get_connection('_').check_hostname is True
 
 
+@pytest.mark.filterwarnings("ignore:BaseException")
 class TestConnection:
     def test_on_connect_error(self):
         """
@@ -479,6 +482,7 @@ def test_on_connect_error(self):
         assert len(pool._available_connections) == 1
         assert not pool._available_connections[0]._sock
 
+    @skip_if_cluster_mode()
     @skip_if_server_version_lt('2.8.8')
     def test_busy_loading_disconnects_socket(self, r):
         """
@@ -489,6 +493,7 @@ def test_busy_loading_disconnects_socket(self, r):
         r.execute_command('DEBUG', 'ERROR', 'LOADING fake message')
         assert not r.connection._sock
 
+    @skip_if_cluster_mode()
     @skip_if_server_version_lt('2.8.8')
     def test_busy_loading_from_pipeline_immediate_command(self, r):
         """
@@ -504,6 +509,7 @@ def test_busy_loading_from_pipeline_immediate_command(self, r):
         assert len(pool._available_connections) == 1
         assert not pool._available_connections[0]._sock
 
+    @skip_if_cluster_mode()
     @skip_if_server_version_lt('2.8.8')
     def test_busy_loading_from_pipeline(self, r):
         """
@@ -519,6 +525,7 @@ def test_busy_loading_from_pipeline(self, r):
         assert len(pool._available_connections) == 1
         assert not pool._available_connections[0]._sock
 
+    @pytest.mark.filterwarnings("ignore:ResponseError")
     @skip_if_server_version_lt('2.8.8')
     def test_read_only_error(self, r):
         "READONLY errors get turned in ReadOnlyError exceptions"
@@ -560,6 +567,7 @@ def test_connect_invalid_password_supplied(self, r):
         r.execute_command('DEBUG', 'ERROR', 'ERR invalid password')
 
 
+@skip_if_cluster_mode()
 class TestMultiConnectionClient:
     @pytest.fixture()
     def r(self, request):
@@ -573,6 +581,7 @@ def test_multi_connection_command(self, r):
         assert r.get('a') == b'123'
 
 
+@skip_if_cluster_mode()
 class TestHealthCheck:
     interval = 60
diff --git a/tests/test_encoding.py b/tests/test_encoding.py
index 706654f89f..955735338b 100644
--- a/tests/test_encoding.py
+++ b/tests/test_encoding.py
@@ -91,6 +91,7 @@ def test_basic_command(self, r):
         r.set('hello', 'world')
 
 
+@pytest.mark.filterwarnings("ignore:BaseException")
 class TestInvalidUserInput:
     def test_boolean_fails(self, r):
         with pytest.raises(redis.DataError):
diff --git a/tests/test_json.py b/tests/test_json.py
index 83fbf28669..c0b4d9ee4c 100644
--- a/tests/test_json.py
+++ b/tests/test_json.py
@@ -1,7 +1,7 @@
 import pytest
 import redis
 from redis.commands.json.path import Path
-from .conftest import skip_ifmodversion_lt
+from .conftest import skip_ifmodversion_lt, skip_if_cluster_mode
 
 
 @pytest.fixture
@@ -10,226 +10,207 @@ def client(modclient):
     return modclient
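
The `skip_if_cluster_mode` marker applied throughout these hunks is defined in
`tests/conftest.py`, whose diff is not shown in this excerpt. As a rough
sketch of the idea, such a helper can be built on `pytest.mark.skipif`; the
`REDIS_CLUSTER` environment flag below is only an assumed trigger, and the
real conftest helper may detect cluster runs differently:

```python
# Hypothetical sketch of the tests/conftest.py helper, not the actual
# implementation; REDIS_CLUSTER is an assumed environment flag.
import os

import pytest


def skip_if_cluster_mode():
    return pytest.mark.skipif(
        os.environ.get('REDIS_CLUSTER') == '1',
        reason="This test is not supported in cluster mode")
```

Because the helper returns a mark rather than being a decorator itself, it is
applied as `@skip_if_cluster_mode()`, and the same mark can sit on a whole
class (as with `TestJson` and `TestLock` below) or on a single test.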
-@pytest.mark.redismod -def test_json_setbinarykey(client): - d = {"hello": "world", b"some": "value"} - with pytest.raises(TypeError): - client.json().set("somekey", Path.rootPath(), d) - assert client.json().set("somekey", Path.rootPath(), d, decode_keys=True) - - -@pytest.mark.redismod -def test_json_setgetdeleteforget(client): - assert client.json().set("foo", Path.rootPath(), "bar") - assert client.json().get("foo") == "bar" - assert client.json().get("baz") is None - assert client.json().delete("foo") == 1 - assert client.json().forget("foo") == 0 # second delete - assert client.exists("foo") == 0 - - -@pytest.mark.redismod -def test_justaget(client): - client.json().set("foo", Path.rootPath(), "bar") - assert client.json().get("foo") == "bar" - - -@pytest.mark.redismod -def test_json_get_jset(client): - assert client.json().set("foo", Path.rootPath(), "bar") - assert "bar" == client.json().get("foo") - assert client.json().get("baz") is None - assert 1 == client.json().delete("foo") - assert client.exists("foo") == 0 - - -@pytest.mark.redismod -def test_nonascii_setgetdelete(client): - assert client.json().set("notascii", Path.rootPath(), - "hyvää-élève") is True - assert "hyvää-élève" == client.json().get("notascii", no_escape=True) - assert 1 == client.json().delete("notascii") - assert client.exists("notascii") == 0 - - -@pytest.mark.redismod -def test_jsonsetexistentialmodifiersshouldsucceed(client): - obj = {"foo": "bar"} - assert client.json().set("obj", Path.rootPath(), obj) - - # Test that flags prevent updates when conditions are unmet - assert client.json().set("obj", Path("foo"), "baz", nx=True) is None - assert client.json().set("obj", Path("qaz"), "baz", xx=True) is None - - # Test that flags allow updates when conditions are met - assert client.json().set("obj", Path("foo"), "baz", xx=True) - assert client.json().set("obj", Path("qaz"), "baz", nx=True) - - # Test that flags are mutually exlusive - with pytest.raises(Exception): - client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) - - -@pytest.mark.redismod -def test_mgetshouldsucceed(client): - client.json().set("1", Path.rootPath(), 1) - client.json().set("2", Path.rootPath(), 2) - r = client.json().mget(Path.rootPath(), "1", "2") - e = [1, 2] - assert e == r - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -def test_clearShouldSucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 1 == client.json().clear("arr", Path.rootPath()) - assert [] == client.json().get("arr") - - -@pytest.mark.redismod -def test_typeshouldsucceed(client): - client.json().set("1", Path.rootPath(), 1) - assert b"integer" == client.json().type("1") - - -@pytest.mark.redismod -def test_numincrbyshouldsucceed(client): - client.json().set("num", Path.rootPath(), 1) - assert 2 == client.json().numincrby("num", Path.rootPath(), 1) - assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) - assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) - - -@pytest.mark.redismod -def test_nummultbyshouldsucceed(client): - client.json().set("num", Path.rootPath(), 1) - assert 2 == client.json().nummultby("num", Path.rootPath(), 2) - assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) - assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release -def test_toggleShouldSucceed(client): - client.json().set("bool", 
Path.rootPath(), False) - assert client.json().toggle("bool", Path.rootPath()) - assert not client.json().toggle("bool", Path.rootPath()) - # check non-boolean value - client.json().set("num", Path.rootPath(), 1) - with pytest.raises(redis.exceptions.ResponseError): - client.json().toggle("num", Path.rootPath()) - - -@pytest.mark.redismod -def test_strappendshouldsucceed(client): - client.json().set("str", Path.rootPath(), "foo") - assert 6 == client.json().strappend("str", "bar", Path.rootPath()) - assert "foobar" == client.json().get("str", Path.rootPath()) - - -@pytest.mark.redismod -def test_debug(client): - client.json().set("str", Path.rootPath(), "foo") - assert 24 == client.json().debug("str", Path.rootPath()) - - -@pytest.mark.redismod -def test_strlenshouldsucceed(client): - client.json().set("str", Path.rootPath(), "foo") - assert 3 == client.json().strlen("str", Path.rootPath()) - client.json().strappend("str", "bar", Path.rootPath()) - assert 6 == client.json().strlen("str", Path.rootPath()) - - -@pytest.mark.redismod -def test_arrappendshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [1]) - assert 2 == client.json().arrappend("arr", Path.rootPath(), 2) - assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) - assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) - - -@pytest.mark.redismod -def testArrIndexShouldSucceed(client): - client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) - assert 1 == client.json().arrindex("arr", Path.rootPath(), 1) - assert -1 == client.json().arrindex("arr", Path.rootPath(), 1, 2) - - -@pytest.mark.redismod -def test_arrinsertshouldsucceed(client): - client.json().set("arr", Path.rootPath(), [0, 4]) - assert 5 - -client.json().arrinsert( - "arr", - Path.rootPath(), - 1, - *[ +@skip_if_cluster_mode() +class TestJson: + @pytest.mark.redismod + def test_json_setbinarykey(self, client): + d = {"hello": "world", b"some": "value"} + with pytest.raises(TypeError): + client.json().set("somekey", Path.rootPath(), d) + assert client.json().set("somekey", Path.rootPath(), d, + decode_keys=True) + + @pytest.mark.redismod + def test_json_setgetdeleteforget(self, client): + assert client.json().set("foo", Path.rootPath(), "bar") + assert client.json().get("foo") == "bar" + assert client.json().get("baz") is None + assert client.json().delete("foo") == 1 + assert client.json().forget("foo") == 0 # second delete + assert client.exists("foo") == 0 + + @pytest.mark.redismod + def test_justaget(self, client): + client.json().set("foo", Path.rootPath(), "bar") + assert client.json().get("foo") == "bar" + + @pytest.mark.redismod + def test_json_get_jset(self, client): + assert client.json().set("foo", Path.rootPath(), "bar") + assert "bar" == client.json().get("foo") + assert client.json().get("baz") is None + assert 1 == client.json().delete("foo") + assert client.exists("foo") == 0 + + @pytest.mark.redismod + def test_nonascii_setgetdelete(self, client): + assert client.json().set("notascii", Path.rootPath(), + "hyvää-élève") is True + assert "hyvää-élève" == client.json().get("notascii", no_escape=True) + assert 1 == client.json().delete("notascii") + assert client.exists("notascii") == 0 + + @pytest.mark.redismod + def test_jsonsetexistentialmodifiersshouldsucceed(self, client): + obj = {"foo": "bar"} + assert client.json().set("obj", Path.rootPath(), obj) + + # Test that flags prevent updates when conditions are unmet + assert client.json().set("obj", Path("foo"), "baz", nx=True) is None + assert 
client.json().set("obj", Path("qaz"), "baz", xx=True) is None + + # Test that flags allow updates when conditions are met + assert client.json().set("obj", Path("foo"), "baz", xx=True) + assert client.json().set("obj", Path("qaz"), "baz", nx=True) + + # Test that flags are mutually exlusive + with pytest.raises(Exception): + client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) + + @pytest.mark.redismod + def test_mgetshouldsucceed(self, client): + client.json().set("1", Path.rootPath(), 1) + client.json().set("2", Path.rootPath(), 2) + r = client.json().mget(Path.rootPath(), "1", "2") + e = [1, 2] + assert e == r + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", + "ReJSON") # todo: update after the release + def test_clearShouldSucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 1 == client.json().clear("arr", Path.rootPath()) + assert [] == client.json().get("arr") + + @pytest.mark.redismod + def test_typeshouldsucceed(self, client): + client.json().set("1", Path.rootPath(), 1) + assert b"integer" == client.json().type("1") + + @pytest.mark.redismod + def test_numincrbyshouldsucceed(self, client): + client.json().set("num", Path.rootPath(), 1) + assert 2 == client.json().numincrby("num", Path.rootPath(), 1) + assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) + assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) + + @pytest.mark.redismod + def test_nummultbyshouldsucceed(self, client): + client.json().set("num", Path.rootPath(), 1) + assert 2 == client.json().nummultby("num", Path.rootPath(), 2) + assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) + assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", + "ReJSON") # todo: update after the release + def test_toggleShouldSucceed(self, client): + client.json().set("bool", Path.rootPath(), False) + assert client.json().toggle("bool", Path.rootPath()) + assert not client.json().toggle("bool", Path.rootPath()) + # check non-boolean value + client.json().set("num", Path.rootPath(), 1) + with pytest.raises(redis.exceptions.ResponseError): + client.json().toggle("num", Path.rootPath()) + + @pytest.mark.redismod + def test_strappendshouldsucceed(self, client): + client.json().set("str", Path.rootPath(), "foo") + assert 6 == client.json().strappend("str", "bar", Path.rootPath()) + assert "foobar" == client.json().get("str", Path.rootPath()) + + @pytest.mark.redismod + def test_debug(self, client): + client.json().set("str", Path.rootPath(), "foo") + assert 24 == client.json().debug("str", Path.rootPath()) + + @pytest.mark.redismod + def test_strlenshouldsucceed(self, client): + client.json().set("str", Path.rootPath(), "foo") + assert 3 == client.json().strlen("str", Path.rootPath()) + client.json().strappend("str", "bar", Path.rootPath()) + assert 6 == client.json().strlen("str", Path.rootPath()) + + @pytest.mark.redismod + def test_arrappendshouldsucceed(self, client): + client.json().set("arr", Path.rootPath(), [1]) + assert 2 == client.json().arrappend("arr", Path.rootPath(), 2) + assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) + assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) + + @pytest.mark.redismod + def testArrIndexShouldSucceed(self, client): + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 1 == client.json().arrindex("arr", Path.rootPath(), 1) + assert -1 == client.json().arrindex("arr", 
Path.rootPath(), 1, 2)
+
+    @pytest.mark.redismod
+    def test_arrinsertshouldsucceed(self, client):
+        client.json().set("arr", Path.rootPath(), [0, 4])
+        assert 5 == client.json().arrinsert(
+            "arr",
+            Path.rootPath(),
             1,
-            2,
-            3,
-        ]
-    )
-    assert [0, 1, 2, 3, 4] == client.json().get("arr")
-
-
-@pytest.mark.redismod
-def test_arrlenshouldsucceed(client):
-    client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4])
-    assert 5 == client.json().arrlen("arr", Path.rootPath())
-
-
-@pytest.mark.redismod
-def test_arrpopshouldsucceed(client):
-    client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4])
-    assert 4 == client.json().arrpop("arr", Path.rootPath(), 4)
-    assert 3 == client.json().arrpop("arr", Path.rootPath(), -1)
-    assert 2 == client.json().arrpop("arr", Path.rootPath())
-    assert 0 == client.json().arrpop("arr", Path.rootPath(), 0)
-    assert [1] == client.json().get("arr")
-
-
-@pytest.mark.redismod
-def test_arrtrimshouldsucceed(client):
-    client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4])
-    assert 3 == client.json().arrtrim("arr", Path.rootPath(), 1, 3)
-    assert [1, 2, 3] == client.json().get("arr")
-
-
-@pytest.mark.redismod
-def test_respshouldsucceed(client):
-    obj = {"foo": "bar", "baz": 1, "qaz": True}
-    client.json().set("obj", Path.rootPath(), obj)
-    assert b"bar" == client.json().resp("obj", Path("foo"))
-    assert 1 == client.json().resp("obj", Path("baz"))
-    assert client.json().resp("obj", Path("qaz"))
-
-
-@pytest.mark.redismod
-def test_objkeysshouldsucceed(client):
-    obj = {"foo": "bar", "baz": "qaz"}
-    client.json().set("obj", Path.rootPath(), obj)
-    keys = client.json().objkeys("obj", Path.rootPath())
-    keys.sort()
-    exp = list(obj.keys())
-    exp.sort()
-    assert exp == keys
-
-
-@pytest.mark.redismod
-def test_objlenshouldsucceed(client):
-    obj = {"foo": "bar", "baz": "qaz"}
-    client.json().set("obj", Path.rootPath(), obj)
-    assert len(obj) == client.json().objlen("obj", Path.rootPath())
-
-
-# @pytest.mark.pipeline
-# @pytest.mark.redismod
-# def test_pipelineshouldsucceed(client):
-#     p = client.json().pipeline()
-#     p.set("foo", Path.rootPath(), "bar")
-#     p.get("foo")
-#     p.delete("foo")
-#     assert [True, "bar", 1] == p.execute()
-#     assert client.keys() == []
-#     assert client.get("foo") is None
+            *[
+                1,
+                2,
+                3,
+            ]
+        )
+        assert [0, 1, 2, 3, 4] == client.json().get("arr")
+
+    @pytest.mark.redismod
+    def test_arrlenshouldsucceed(self, client):
+        client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4])
+        assert 5 == client.json().arrlen("arr", Path.rootPath())
+
+    @pytest.mark.redismod
+    def test_arrpopshouldsucceed(self, client):
+        client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4])
+        assert 4 == client.json().arrpop("arr", Path.rootPath(), 4)
+        assert 3 == client.json().arrpop("arr", Path.rootPath(), -1)
+        assert 2 == client.json().arrpop("arr", Path.rootPath())
+        assert 0 == client.json().arrpop("arr", Path.rootPath(), 0)
+        assert [1] == client.json().get("arr")
+
+    @pytest.mark.redismod
+    def test_arrtrimshouldsucceed(self, client):
+        client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4])
+        assert 3 == client.json().arrtrim("arr", Path.rootPath(), 1, 3)
+        assert [1, 2, 3] == client.json().get("arr")
+
+    @pytest.mark.redismod
+    def test_respshouldsucceed(self, client):
+        obj = {"foo": "bar", "baz": 1, "qaz": True}
+        client.json().set("obj", Path.rootPath(), obj)
+        assert b"bar" == client.json().resp("obj", Path("foo"))
+        assert 1 == client.json().resp("obj", Path("baz"))
+        assert client.json().resp("obj", Path("qaz"))
+
+    @pytest.mark.redismod
+    
def test_objkeysshouldsucceed(self, client): + obj = {"foo": "bar", "baz": "qaz"} + client.json().set("obj", Path.rootPath(), obj) + keys = client.json().objkeys("obj", Path.rootPath()) + keys.sort() + exp = list(obj.keys()) + exp.sort() + assert exp == keys + + @pytest.mark.redismod + def test_objlenshouldsucceed(self, client): + obj = {"foo": "bar", "baz": "qaz"} + client.json().set("obj", Path.rootPath(), obj) + assert len(obj) == client.json().objlen("obj", Path.rootPath()) + + # @pytest.mark.pipeline + # @pytest.mark.redismod + # def test_pipelineshouldsucceed(client): + # p = client.json().pipeline() + # p.set("foo", Path.rootPath(), "bar") + # p.get("foo") + # p.delete("foo") + # assert [True, "bar", 1] == p.execute() + # assert client.keys() == [] + # assert client.get("foo") is None diff --git a/tests/test_lock.py b/tests/test_lock.py index fa76385221..ab62dfc820 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -4,9 +4,10 @@ from redis.exceptions import LockError, LockNotOwnedError from redis.client import Redis from redis.lock import Lock -from .conftest import _get_client +from .conftest import _get_client, skip_if_cluster_mode +@skip_if_cluster_mode() class TestLock: @pytest.fixture() def r_decoded(self, request): @@ -220,6 +221,7 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r): lock.reacquire() +@skip_if_cluster_mode() class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 1013202f22..5d065c9206 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,6 +1,7 @@ -from .conftest import wait_for_command +from .conftest import wait_for_command, skip_if_cluster_mode +@skip_if_cluster_mode() class TestMonitor: def test_wait_command_not_found(self, r): "Make sure the wait_for_command func works when command is not found" diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 2d27c4e8bb..a298af39c7 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -30,12 +30,12 @@ def r(self, request): request=request, single_connection_client=False) - def test_close_connection_in_child(self, master_host): + def test_close_connection_in_child(self, master_host, master_port): """ A connection owned by a parent and closed by a child doesn't destroy the file descriptors so a parent can still use it. """ - conn = Connection(host=master_host) + conn = Connection(host=master_host, port=master_port) conn.send_command('ping') assert conn.read_response() == b'PONG' @@ -56,12 +56,12 @@ def target(conn): conn.send_command('ping') assert conn.read_response() == b'PONG' - def test_close_connection_in_parent(self, master_host): + def test_close_connection_in_parent(self, master_host, master_port): """ A connection owned by a parent is unusable by a child if the parent (the owning process) closes the connection. """ - conn = Connection(host=master_host) + conn = Connection(host=master_host, port=master_port) conn.send_command('ping') assert conn.read_response() == b'PONG' @@ -84,12 +84,13 @@ def target(conn, ev): assert proc.exitcode == 0 @pytest.mark.parametrize('max_connections', [1, 2, None]) - def test_pool(self, max_connections, master_host): + def test_pool(self, max_connections, master_host, master_port): """ A child will create its own connections when using a pool created by a parent. 
""" - pool = ConnectionPool.from_url('redis://{}'.format(master_host), + pool = ConnectionPool.from_url('redis://{}:{}'.format(master_host, + master_port), max_connections=max_connections) conn = pool.get_connection('ping') @@ -119,12 +120,14 @@ def target(pool): assert conn.read_response() == b'PONG' @pytest.mark.parametrize('max_connections', [1, 2, None]) - def test_close_pool_in_main(self, max_connections, master_host): + def test_close_pool_in_main(self, max_connections, master_host, + master_port): """ A child process that uses the same pool as its parent isn't affected when the parent disconnects all connections within the pool. """ - pool = ConnectionPool.from_url('redis://{}'.format(master_host), + pool = ConnectionPool.from_url('redis://{}:{}'.format(master_host, + master_port), max_connections=max_connections) conn = pool.get_connection('ping') diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 08bd40bacd..8fadf46bf1 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,7 +1,8 @@ import pytest import redis -from .conftest import wait_for_command, skip_if_server_version_lt +from .conftest import wait_for_command, skip_if_server_version_lt, \ + skip_if_cluster_mode class TestPipeline: @@ -59,6 +60,7 @@ def test_pipeline_no_transaction(self, r): assert r['b'] == b'b1' assert r['c'] == b'c1' + @skip_if_cluster_mode() def test_pipeline_no_transaction_watch(self, r): r['a'] = 0 @@ -70,6 +72,7 @@ def test_pipeline_no_transaction_watch(self, r): pipe.set('a', int(a) + 1) assert pipe.execute() == [True] + @skip_if_cluster_mode() def test_pipeline_no_transaction_watch_failure(self, r): r['a'] = 0 @@ -129,6 +132,7 @@ def test_exec_error_raised(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + @skip_if_cluster_mode() def test_transaction_with_empty_error_command(self, r): """ Commands with custom EMPTY_ERROR functionality return their default @@ -143,6 +147,7 @@ def test_transaction_with_empty_error_command(self, r): assert result[1] == [] assert result[2] + @skip_if_cluster_mode() def test_pipeline_with_empty_error_command(self, r): """ Commands with custom EMPTY_ERROR functionality return their default @@ -171,6 +176,7 @@ def test_parse_error_raised(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + @skip_if_cluster_mode() def test_parse_error_raised_transaction(self, r): with r.pipeline() as pipe: pipe.multi() @@ -186,6 +192,7 @@ def test_parse_error_raised_transaction(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' + @skip_if_cluster_mode() def test_watch_succeed(self, r): r['a'] = 1 r['b'] = 2 @@ -203,6 +210,7 @@ def test_watch_succeed(self, r): assert pipe.execute() == [True] assert not pipe.watching + @skip_if_cluster_mode() def test_watch_failure(self, r): r['a'] = 1 r['b'] = 2 @@ -217,6 +225,7 @@ def test_watch_failure(self, r): assert not pipe.watching + @skip_if_cluster_mode() def test_watch_failure_in_empty_transaction(self, r): r['a'] = 1 r['b'] = 2 @@ -230,6 +239,7 @@ def test_watch_failure_in_empty_transaction(self, r): assert not pipe.watching + @skip_if_cluster_mode() def test_unwatch(self, r): r['a'] = 1 r['b'] = 2 @@ -242,6 +252,7 @@ def test_unwatch(self, r): pipe.get('a') assert pipe.execute() == [b'1'] + @skip_if_cluster_mode() def test_watch_exec_no_unwatch(self, r): r['a'] = 1 r['b'] = 2 @@ -262,6 +273,7 @@ def test_watch_exec_no_unwatch(self, r): unwatch_command = wait_for_command(r, m, 'UNWATCH') assert unwatch_command is None, 
"should not send UNWATCH" + @skip_if_cluster_mode() def test_watch_reset_unwatch(self, r): r['a'] = 1 @@ -276,6 +288,7 @@ def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command['command'] == 'UNWATCH' + @skip_if_cluster_mode() def test_transaction_callable(self, r): r['a'] = 1 r['b'] = 2 @@ -300,6 +313,7 @@ def my_transaction(pipe): assert result == [True] assert r['c'] == b'4' + @skip_if_cluster_mode() def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): # No need to do anything here since we only want the return value @@ -354,6 +368,7 @@ def test_pipeline_with_bitfield(self, r): assert pipe == pipe2 assert response == [True, [0, 0, 15, 15, 14], b'1'] + @skip_if_cluster_mode() @skip_if_server_version_lt('2.0.0') def test_pipeline_discard(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 6a4f0aafa4..ebb96de58b 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -7,7 +7,8 @@ import redis from redis.exceptions import ConnectionError -from .conftest import _get_client, skip_if_server_version_lt +from .conftest import _get_client, skip_if_cluster_mode, \ + skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -119,6 +120,7 @@ def test_resubscribe_to_channels_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_resubscribe_on_reconnection(**kwargs) + @skip_if_cluster_mode() def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_resubscribe_on_reconnection(**kwargs) @@ -173,6 +175,7 @@ def test_subscribe_property_with_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_subscribed_property(**kwargs) + @skip_if_cluster_mode() def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_subscribed_property(**kwargs) @@ -216,6 +219,7 @@ def test_sub_unsub_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_sub_unsub_resub(**kwargs) + @skip_if_cluster_mode() def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_sub_unsub_resub(**kwargs) @@ -303,6 +307,7 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message('message', 'foo', 'test message') + @skip_if_cluster_mode() def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.psubscribe(**{'f*': self.message_handler}) @@ -322,6 +327,9 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message('message', channel, 'test message') + @skip_if_cluster_mode() + # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html + # #known-limitations-with-pubsub def test_unicode_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) pattern = 'uni' + chr(4456) + '*' @@ -397,6 +405,7 @@ def test_channel_publish(self, r): self.channel, self.data) + @skip_if_cluster_mode() def test_pattern_publish(self, r): p = r.pubsub() p.psubscribe(self.pattern) @@ -493,7 +502,7 @@ def test_pubsub_numsub(self, r): assert wait_for_message(p3)['type'] == 'subscribe' channels = [(b'foo', 1), (b'bar', 2), (b'baz', 3)] - assert channels == r.pubsub_numsub('foo', 'bar', 'baz') + assert r.pubsub_numsub('foo', 'bar', 'baz') == channels 
@skip_if_server_version_lt('2.8.0') def test_pubsub_numpat(self, r): @@ -525,6 +534,7 @@ def test_send_pubsub_ping_message(self, r): pattern=None) +@skip_if_cluster_mode() class TestPubSubConnectionKilled: @skip_if_server_version_lt('3.0.0') diff --git a/tests/test_scripting.py b/tests/test_scripting.py index c3c2094d4a..46a684e36d 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,7 +1,9 @@ import pytest from redis import exceptions - +from .conftest import ( + skip_if_cluster_mode, +) multiply_script = """ local value = redis.call('GET', KEYS[1]) @@ -20,6 +22,7 @@ """ +@skip_if_cluster_mode() class TestScripting: @pytest.fixture(autouse=True) def reset_scripts(self, r): diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 54cf262c43..7f66603085 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -5,6 +5,7 @@ from redis import exceptions from redis.sentinel import (Sentinel, SentinelConnectionPool, MasterNotFoundError, SlaveNotFoundError) +from .conftest import skip_if_cluster_mode import redis.sentinel @@ -13,6 +14,7 @@ def master_ip(master_host): yield socket.gethostbyname(master_host) +@skip_if_cluster_mode() class SentinelTestClient: def __init__(self, cluster, id): self.cluster = cluster @@ -36,6 +38,24 @@ def execute_command(self, *args, **kwargs): return bool_ok +@pytest.fixture() +def cluster(request, master_ip): + def teardown(): + redis.sentinel.Redis = saved_Redis + + cluster = SentinelTestCluster(ip=master_ip) + saved_Redis = redis.sentinel.Redis + redis.sentinel.Redis = cluster.client + request.addfinalizer(teardown) + return cluster + + +@pytest.fixture() +def sentinel(request, cluster): + return Sentinel([('foo', 26379), ('bar', 26379)]) + + +@skip_if_cluster_mode() class SentinelTestCluster: def __init__(self, servisentinel_ce_name='mymaster', ip='127.0.0.1', port=6379): @@ -64,156 +84,129 @@ def timeout_if_down(self, node): def client(self, host, port, **kwargs): return SentinelTestClient(self, (host, port)) - -@pytest.fixture() -def cluster(request, master_ip): - def teardown(): - redis.sentinel.Redis = saved_Redis - cluster = SentinelTestCluster(ip=master_ip) - saved_Redis = redis.sentinel.Redis - redis.sentinel.Redis = cluster.client - request.addfinalizer(teardown) - return cluster - - -@pytest.fixture() -def sentinel(request, cluster): - return Sentinel([('foo', 26379), ('bar', 26379)]) - - -def test_discover_master(sentinel, master_ip): - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - - -def test_discover_master_error(sentinel): - with pytest.raises(MasterNotFoundError): - sentinel.discover_master('xxx') - - -def test_discover_master_sentinel_down(cluster, sentinel, master_ip): - # Put first sentinel 'foo' down - cluster.nodes_down.add(('foo', 26379)) - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ('bar', 26379) - - -def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): - # Put first sentinel 'foo' down - cluster.nodes_timeout.add(('foo', 26379)) - address = sentinel.discover_master('mymaster') - assert address == (master_ip, 6379) - # 'bar' is now first sentinel - assert sentinel.sentinels[0].id == ('bar', 26379) - - -def test_master_min_other_sentinels(cluster, master_ip): - sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1) - # min_other_sentinels - with pytest.raises(MasterNotFoundError): - sentinel.discover_master('mymaster') - 
cluster.master['num-other-sentinels'] = 2
-    address = sentinel.discover_master('mymaster')
-    assert address == (master_ip, 6379)
-
-
-def test_master_odown(cluster, sentinel):
-    cluster.master['is_odown'] = True
-    with pytest.raises(MasterNotFoundError):
-        sentinel.discover_master('mymaster')
-
-
-def test_master_sdown(cluster, sentinel):
-    cluster.master['is_sdown'] = True
-    with pytest.raises(MasterNotFoundError):
-        sentinel.discover_master('mymaster')
-
-
-def test_discover_slaves(cluster, sentinel):
-    assert sentinel.discover_slaves('mymaster') == []
-
-    cluster.slaves = [
-        {'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
-        {'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
-    ]
-    assert sentinel.discover_slaves('mymaster') == [
-        ('slave0', 1234), ('slave1', 1234)]
-
-    # slave0 -> ODOWN
-    cluster.slaves[0]['is_odown'] = True
-    assert sentinel.discover_slaves('mymaster') == [
-        ('slave1', 1234)]
-
-    # slave1 -> SDOWN
-    cluster.slaves[1]['is_sdown'] = True
-    assert sentinel.discover_slaves('mymaster') == []
-
-    cluster.slaves[0]['is_odown'] = False
-    cluster.slaves[1]['is_sdown'] = False
-
-    # node0 -> DOWN
-    cluster.nodes_down.add(('foo', 26379))
-    assert sentinel.discover_slaves('mymaster') == [
-        ('slave0', 1234), ('slave1', 1234)]
-    cluster.nodes_down.clear()
-
-    # node0 -> TIMEOUT
-    cluster.nodes_timeout.add(('foo', 26379))
-    assert sentinel.discover_slaves('mymaster') == [
-        ('slave0', 1234), ('slave1', 1234)]
-
-
-def test_master_for(cluster, sentinel, master_ip):
-    master = sentinel.master_for('mymaster', db=9)
-    assert master.ping()
-    assert master.connection_pool.master_address == (master_ip, 6379)
-
-    # Use internal connection check
-    master = sentinel.master_for('mymaster', db=9, check_connection=True)
-    assert master.ping()
-
-
-def test_slave_for(cluster, sentinel):
-    cluster.slaves = [
-        {'ip': '127.0.0.1', 'port': 6379,
-         'is_odown': False, 'is_sdown': False},
-    ]
-    slave = sentinel.slave_for('mymaster', db=9)
-    assert slave.ping()
-
-
-def test_slave_for_slave_not_found_error(cluster, sentinel):
-    cluster.master['is_odown'] = True
-    slave = sentinel.slave_for('mymaster', db=9)
-    with pytest.raises(SlaveNotFoundError):
-        slave.ping()
-
-
-def test_slave_round_robin(cluster, sentinel, master_ip):
-    cluster.slaves = [
-        {'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
-        {'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
-    ]
-    pool = SentinelConnectionPool('mymaster', sentinel)
-    rotator = pool.rotate_slaves()
-    assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
-    assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
-    # Fallback to master
-    assert next(rotator) == (master_ip, 6379)
-    with pytest.raises(SlaveNotFoundError):
-        next(rotator)
-
-
-def test_ckquorum(cluster, sentinel):
-    assert sentinel.sentinel_ckquorum("mymaster")
-
-
-def test_flushconfig(cluster, sentinel):
-    assert sentinel.sentinel_flushconfig()
-
-
-def test_reset(cluster, sentinel):
-    cluster.master['is_odown'] = True
-    assert sentinel.sentinel_reset('mymaster')
+
+
+# Grouped under a test class so the cluster-mode skip covers all of them;
+# methods need `self` to be collected and run correctly by pytest
+@skip_if_cluster_mode()
+class TestSentinel:
+    def test_discover_master(self, sentinel, master_ip):
+        address = sentinel.discover_master('mymaster')
+        assert address == (master_ip, 6379)
+
+    def test_discover_master_error(self, sentinel):
+        with pytest.raises(MasterNotFoundError):
+            sentinel.discover_master('xxx')
+
+    def test_discover_master_sentinel_down(self, cluster, sentinel,
+                                           master_ip):
+        # Put first sentinel 'foo' down
+        cluster.nodes_down.add(('foo', 26379))
+        address = sentinel.discover_master('mymaster')
+        assert address == (master_ip, 6379)
+        # 'bar' is now first sentinel
+        assert sentinel.sentinels[0].id == ('bar', 26379)
+
+    def test_discover_master_sentinel_timeout(self, cluster, sentinel,
+                                              master_ip):
+        # Put first sentinel 'foo' down
+        cluster.nodes_timeout.add(('foo', 26379))
+        address = sentinel.discover_master('mymaster')
+        assert address == (master_ip, 6379)
+        # 'bar' is now first sentinel
+        assert sentinel.sentinels[0].id == ('bar', 26379)
+
+    def test_master_min_other_sentinels(self, cluster, master_ip):
+        sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
+        # min_other_sentinels
+        with pytest.raises(MasterNotFoundError):
+            sentinel.discover_master('mymaster')
+        cluster.master['num-other-sentinels'] = 2
+        address = sentinel.discover_master('mymaster')
+        assert address == (master_ip, 6379)
+
+    def test_master_odown(self, cluster, sentinel):
+        cluster.master['is_odown'] = True
+        with pytest.raises(MasterNotFoundError):
+            sentinel.discover_master('mymaster')
+
+    def test_master_sdown(self, cluster, sentinel):
+        cluster.master['is_sdown'] = True
+        with pytest.raises(MasterNotFoundError):
+            sentinel.discover_master('mymaster')
+
+    def test_discover_slaves(self, cluster, sentinel):
+        assert sentinel.discover_slaves('mymaster') == []
+
+        cluster.slaves = [
+            {'ip': 'slave0', 'port': 1234, 'is_odown': False,
+             'is_sdown': False},
+            {'ip': 'slave1', 'port': 1234, 'is_odown': False,
+             'is_sdown': False},
+        ]
+        assert sentinel.discover_slaves('mymaster') == [
+            ('slave0', 1234), ('slave1', 1234)]
+
+        # slave0 -> ODOWN
+        cluster.slaves[0]['is_odown'] = True
+        assert sentinel.discover_slaves('mymaster') == [
+            ('slave1', 1234)]
+
+        # slave1 -> SDOWN
+        cluster.slaves[1]['is_sdown'] = True
+        assert sentinel.discover_slaves('mymaster') == []
+
+        cluster.slaves[0]['is_odown'] = False
+        cluster.slaves[1]['is_sdown'] = False
+
+        # node0 -> DOWN
+        cluster.nodes_down.add(('foo', 26379))
+        assert sentinel.discover_slaves('mymaster') == [
+            ('slave0', 1234), ('slave1', 1234)]
+        cluster.nodes_down.clear()
+
+        # node0 -> TIMEOUT
+        cluster.nodes_timeout.add(('foo', 26379))
+        assert sentinel.discover_slaves('mymaster') == [
+            ('slave0', 1234), ('slave1', 1234)]
+
+    def test_master_for(self, cluster, sentinel, master_ip):
+        master = sentinel.master_for('mymaster', db=9)
+        assert master.ping()
+        assert master.connection_pool.master_address == (master_ip, 6379)
+
+        # Use internal connection check
+        master = sentinel.master_for('mymaster', db=9, check_connection=True)
+        assert master.ping()
+
+    def test_slave_for(self, cluster, sentinel):
+        cluster.slaves = [
+            {'ip': '127.0.0.1', 'port': 6379,
+             'is_odown': False, 'is_sdown': False},
+        ]
+        slave = sentinel.slave_for('mymaster', db=9)
+        assert slave.ping()
+
+    def test_slave_for_slave_not_found_error(self, cluster, sentinel):
+        cluster.master['is_odown'] = True
+        slave = sentinel.slave_for('mymaster', db=9)
+        with pytest.raises(SlaveNotFoundError):
+            slave.ping()
+
+    def test_slave_round_robin(self, cluster, sentinel, master_ip):
+        cluster.slaves = [
+            {'ip': 'slave0', 'port': 6379, 'is_odown': False,
+             'is_sdown': False},
+            {'ip': 'slave1', 'port': 6379, 'is_odown': False,
+             'is_sdown': False},
+        ]
+        pool = SentinelConnectionPool('mymaster', sentinel)
+        rotator = pool.rotate_slaves()
+        assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
+        assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
+        # Fallback to master
+        assert next(rotator) == (master_ip, 6379)
+        with pytest.raises(SlaveNotFoundError):
+            next(rotator)
+
+    def test_ckquorum(self, cluster, sentinel):
+        assert sentinel.sentinel_ckquorum("mymaster")
+
+    def test_flushconfig(self, cluster, sentinel):
+        assert sentinel.sentinel_flushconfig()
+
+    def test_reset(self, cluster, sentinel):
+        cluster.master['is_odown'] = True
+        assert sentinel.sentinel_reset('mymaster')
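
These tests drive the sentinel API through the `SentinelTestCluster` mock
defined above. For reference, the same calls against a real sentinel
deployment look like this (a sketch; host, port and service name are
placeholders):

``` pycon
    >>> from redis.sentinel import Sentinel
    >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1)
    >>> sentinel.discover_master('mymaster')
    ('127.0.0.1', 6379)
    >>> master = sentinel.master_for('mymaster', socket_timeout=0.1)
    >>> master.ping()
    True
    >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1)
    >>> slave.ping()
    True
```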
diff --git a/tox.ini b/tox.ini
index 67b7e7575d..e46a0c47f5 100644
--- a/tox.ini
+++ b/tox.ini
@@ -74,6 +74,21 @@ image = redisfab/lots-of-pythons
 volumes =
     bind:rw:{toxinidir}:/data
 
+[docker:redis_cluster]
+name = redis_cluster
+image = barshaul/redis-py:6.2.6-cluster
+healthcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',16379)) else False"
+ports =
+    16379:16379/tcp
+    16380:16380/tcp
+    16381:16381/tcp
+    16382:16382/tcp
+    16383:16383/tcp
+    16384:16384/tcp
+volumes =
+    bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf
+
+
 [testenv]
 deps = -r {toxinidir}/dev_requirements.txt
 docker =
@@ -82,6 +97,7 @@ docker =
     sentinel_1
     sentinel_2
     sentinel_3
+    redis_cluster
     redismod
 extras =
     hiredis: hiredis
@@ -98,6 +114,7 @@ docker =
     sentinel_1
     sentinel_2
     sentinel_3
+    redis_cluster
     redismod
     lots-of-pythons
 commands = /usr/bin/echo

From 40a5893f4f4903f4c04bf438beb5be7a9cdb319e Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Thu, 28 Oct 2021 10:59:50 +0300
Subject: [PATCH 02/22] Moved timeseries tests to a class

---
 tests/test_timeseries.py | 1107 +++++++++++++++++++-------------------
 1 file changed, 547 insertions(+), 560 deletions(-)

diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py
index b2df3feda5..e066a7f385 100644
--- a/tests/test_timeseries.py
+++ b/tests/test_timeseries.py
@@ -1,7 +1,7 @@
 import pytest
 import time
 from time import sleep
-from .conftest import skip_ifmodversion_lt
+from .conftest import skip_ifmodversion_lt, skip_if_cluster_mode
 
 
 @pytest.fixture
@@ -10,584 +10,571 @@ def client(modclient):
     return modclient
 
 
-@pytest.mark.redismod
-def testCreate(client):
-    assert client.ts().create(1)
-    assert client.ts().create(2, retention_msecs=5)
-    assert client.ts().create(3, labels={"Redis": "Labs"})
-    assert client.ts().create(4, retention_msecs=20, labels={"Time": "Series"})
-    info = client.ts().info(4)
-    assert 20 == info.retention_msecs
-    assert "Series" == info.labels["Time"]
-
-    # Test for a chunk size of 128 Bytes
-    assert client.ts().create("time-serie-1", chunk_size=128)
-    info = client.ts().info("time-serie-1")
-    assert 128, info.chunk_size
-
-
-@pytest.mark.redismod
-@skip_ifmodversion_lt("1.4.0", "timeseries")
-def testCreateDuplicatePolicy(client):
-    # Test for duplicate policy
-    for duplicate_policy in ["block", "last", "first", "min", "max"]:
-        ts_name = "time-serie-ooo-{0}".format(duplicate_policy)
-        assert client.ts().create(ts_name, duplicate_policy=duplicate_policy)
-        info = client.ts().info(ts_name)
-        assert duplicate_policy == info.duplicate_policy
-
-
-@pytest.mark.redismod
-def testAlter(client):
-    assert client.ts().create(1)
-    assert 0 == client.ts().info(1).retention_msecs
-    assert client.ts().alter(1, retention_msecs=10)
-    assert {} == client.ts().info(1).labels
-    assert 10, client.ts().info(1).retention_msecs
-    assert client.ts().alter(1, labels={"Time": "Series"})
-    assert "Series" == client.ts().info(1).labels["Time"]
-    assert 10 == client.ts().info(1).retention_msecs
-
-
-# pipe = client.ts().pipeline()
-# assert pipe.create(2)
-
-
-@pytest.mark.redismod
-@skip_ifmodversion_lt("1.4.0", "timeseries")
-def testAlterDiplicatePolicy(client):
-    assert client.ts().create(1)
-    info = client.ts().info(1)
-    
assert info.duplicate_policy is None - assert client.ts().alter(1, duplicate_policy="min") - info = client.ts().info(1) - assert "min" == info.duplicate_policy - - -@pytest.mark.redismod -def testAdd(client): - assert 1 == client.ts().add(1, 1, 1) - assert 2 == client.ts().add(2, 2, 3, retention_msecs=10) - assert 3 == client.ts().add(3, 3, 2, labels={"Redis": "Labs"}) - assert 4 == client.ts().add( - 4, 4, 2, retention_msecs=10, labels={"Redis": "Labs", "Time": "Series"} - ) - assert round(time.time()) == \ - round(float(client.ts().add(5, "*", 1)) / 1000) - - info = client.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] - - # Test for a chunk size of 128 Bytes on TS.ADD - assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size - - -@pytest.mark.redismod -@skip_ifmodversion_lt("1.4.0", "timeseries") -def testAddDuplicatePolicy(client): - - # Test for duplicate policy BLOCK - assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) - with pytest.raises(Exception): - client.ts().add( - "time-serie-add-ooo-block", - 1, - 5.0, - duplicate_policy="block" - ) - - # Test for duplicate policy LAST - assert 1 == client.ts().add("time-serie-add-ooo-last", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last" - ) - assert 10.0 == client.ts().get("time-serie-add-ooo-last")[1] - - # Test for duplicate policy FIRST - assert 1 == client.ts().add("time-serie-add-ooo-first", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first" - ) - assert 5.0 == client.ts().get("time-serie-add-ooo-first")[1] - - # Test for duplicate policy MAX - assert 1 == client.ts().add("time-serie-add-ooo-max", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max" - ) - assert 10.0 == client.ts().get("time-serie-add-ooo-max")[1] - - # Test for duplicate policy MIN - assert 1 == client.ts().add("time-serie-add-ooo-min", 1, 5.0) - assert 1 == client.ts().add( - "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min" - ) - assert 5.0 == client.ts().get("time-serie-add-ooo-min")[1] - - -@pytest.mark.redismod -def testMAdd(client): - client.ts().create("a") - assert [1, 2, 3] == \ - client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) - - -@pytest.mark.redismod -def testIncrbyDecrby(client): - for _ in range(100): - assert client.ts().incrby(1, 1) - sleep(0.001) - assert 100 == client.ts().get(1)[1] - for _ in range(100): - assert client.ts().decrby(1, 1) - sleep(0.001) - assert 0 == client.ts().get(1)[1] - - assert client.ts().incrby(2, 1.5, timestamp=5) - assert (5, 1.5) == client.ts().get(2) - assert client.ts().incrby(2, 2.25, timestamp=7) - assert (7, 3.75) == client.ts().get(2) - assert client.ts().decrby(2, 1.5, timestamp=15) - assert (15, 2.25) == client.ts().get(2) - - # Test for a chunk size of 128 Bytes on TS.INCRBY - assert client.ts().incrby("time-serie-1", 10, chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size - - # Test for a chunk size of 128 Bytes on TS.DECRBY - assert client.ts().decrby("time-serie-2", 10, chunk_size=128) - info = client.ts().info("time-serie-2") - assert 128 == info.chunk_size - - -@pytest.mark.redismod -def testCreateAndDeleteRule(client): - # test rule creation - time = 100 - client.ts().create(1) - client.ts().create(2) - client.ts().createrule(1, 2, "avg", 100) - for i in range(50): - 
client.ts().add(1, time + i * 2, 1)
-    client.ts().add(1, time + i * 2 + 1, 2)
-    client.ts().add(1, time * 2, 1.5)
-    assert round(client.ts().get(2)[1], 5) == 1.5
-    info = client.ts().info(1)
-    assert info.rules[0][1] == 100
-
-    # test rule deletion
-    client.ts().deleterule(1, 2)
-    info = client.ts().info(1)
-    assert not info.rules
-
-
-@pytest.mark.redismod
-@skip_ifmodversion_lt("99.99.99", "timeseries")
-def testDelRange(client):
-    try:
-        client.ts().delete("test", 0, 100)
-    except Exception as e:
-        assert e.__str__() != ""
-
-    for i in range(100):
-        client.ts().add(1, i, i % 7)
-    assert 22 == client.ts().delete(1, 0, 21)
-    assert [] == client.ts().range(1, 0, 21)
-    assert [(22, 1.0)] == client.ts().range(1, 22, 22)
-
-
-@pytest.mark.redismod
-def testRange(client):
-    for i in range(100):
-        client.ts().add(1, i, i % 7)
-    assert 100 == len(client.ts().range(1, 0, 200))
-    for i in range(100):
-        client.ts().add(1, i + 200, i % 7)
-    assert 200 == len(client.ts().range(1, 0, 500))
-    # last sample isn't returned
-    assert 20 == len(
-        client.ts().range(
-            1,
+@skip_if_cluster_mode()
+class TestTimeseries:
+    @pytest.mark.redismod
+    def testCreate(self, client):
+        assert client.ts().create(1)
+        assert client.ts().create(2, retention_msecs=5)
+        assert client.ts().create(3, labels={"Redis": "Labs"})
+        assert client.ts().create(4, retention_msecs=20,
+                                  labels={"Time": "Series"})
+        info = client.ts().info(4)
+        assert 20 == info.retention_msecs
+        assert "Series" == info.labels["Time"]
+
+        # Test for a chunk size of 128 Bytes
+        assert client.ts().create("time-serie-1", chunk_size=128)
+        info = client.ts().info("time-serie-1")
+        assert 128 == info.chunk_size
+
+    @pytest.mark.redismod
+    @skip_ifmodversion_lt("1.4.0", "timeseries")
+    def testCreateDuplicatePolicy(self, client):
+        # Test for duplicate policy
+        for duplicate_policy in ["block", "last", "first", "min", "max"]:
+            ts_name = "time-serie-ooo-{0}".format(duplicate_policy)
+            assert client.ts().create(ts_name,
+                                      duplicate_policy=duplicate_policy)
+            info = client.ts().info(ts_name)
+            assert duplicate_policy == info.duplicate_policy
+
+    @pytest.mark.redismod
+    def testAlter(self, client):
+        assert client.ts().create(1)
+        assert 0 == client.ts().info(1).retention_msecs
+        assert client.ts().alter(1, retention_msecs=10)
+        assert {} == client.ts().info(1).labels
+        assert 10 == client.ts().info(1).retention_msecs
+        assert client.ts().alter(1, labels={"Time": "Series"})
+        assert "Series" == client.ts().info(1).labels["Time"]
+        assert 10 == client.ts().info(1).retention_msecs
+
+    # pipe = client.ts().pipeline()
+    # assert pipe.create(2)
+
+    @pytest.mark.redismod
+    @skip_ifmodversion_lt("1.4.0", "timeseries")
+    def testAlterDuplicatePolicy(self, client):
+        assert client.ts().create(1)
+        info = client.ts().info(1)
+        assert info.duplicate_policy is None
+        assert client.ts().alter(1, duplicate_policy="min")
+        info = client.ts().info(1)
+        assert "min" == info.duplicate_policy
+
+    @pytest.mark.redismod
+    def testAdd(self, client):
+        assert 1 == client.ts().add(1, 1, 1)
+        assert 2 == client.ts().add(2, 2, 3, retention_msecs=10)
+        assert 3 == client.ts().add(3, 3, 2, labels={"Redis": "Labs"})
+        assert 4 == client.ts().add(
+            4, 4, 2, retention_msecs=10,
+            labels={"Redis": "Labs", "Time": "Series"}
+        )
+        assert round(time.time()) == \
+            round(float(client.ts().add(5, "*", 1)) / 1000)
+
+        info = client.ts().info(4)
+        assert 10 == info.retention_msecs
+        assert "Labs" == info.labels["Redis"]
+
+        # Test for a chunk size of 128 Bytes on TS.ADD
+        assert 
client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) + info = client.ts().info("time-serie-1") + assert 128 == info.chunk_size + + @pytest.mark.redismod + @skip_ifmodversion_lt("1.4.0", "timeseries") + def testAddDuplicatePolicy(self, client): + + # Test for duplicate policy BLOCK + assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) + with pytest.raises(Exception): + client.ts().add( + "time-serie-add-ooo-block", + 1, + 5.0, + duplicate_policy="block" + ) + + # Test for duplicate policy LAST + assert 1 == client.ts().add("time-serie-add-ooo-last", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last" + ) + assert 10.0 == client.ts().get("time-serie-add-ooo-last")[1] + + # Test for duplicate policy FIRST + assert 1 == client.ts().add("time-serie-add-ooo-first", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first" + ) + assert 5.0 == client.ts().get("time-serie-add-ooo-first")[1] + + # Test for duplicate policy MAX + assert 1 == client.ts().add("time-serie-add-ooo-max", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max" + ) + assert 10.0 == client.ts().get("time-serie-add-ooo-max")[1] + + # Test for duplicate policy MIN + assert 1 == client.ts().add("time-serie-add-ooo-min", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min" + ) + assert 5.0 == client.ts().get("time-serie-add-ooo-min")[1] + + @pytest.mark.redismod + def testMAdd(self, client): + client.ts().create("a") + assert [1, 2, 3] == \ + client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) + + @pytest.mark.redismod + def testIncrbyDecrby(self, client): + for _ in range(100): + assert client.ts().incrby(1, 1) + sleep(0.001) + assert 100 == client.ts().get(1)[1] + for _ in range(100): + assert client.ts().decrby(1, 1) + sleep(0.001) + assert 0 == client.ts().get(1)[1] + + assert client.ts().incrby(2, 1.5, timestamp=5) + assert (5, 1.5) == client.ts().get(2) + assert client.ts().incrby(2, 2.25, timestamp=7) + assert (7, 3.75) == client.ts().get(2) + assert client.ts().decrby(2, 1.5, timestamp=15) + assert (15, 2.25) == client.ts().get(2) + + # Test for a chunk size of 128 Bytes on TS.INCRBY + assert client.ts().incrby("time-serie-1", 10, chunk_size=128) + info = client.ts().info("time-serie-1") + assert 128 == info.chunk_size + + # Test for a chunk size of 128 Bytes on TS.DECRBY + assert client.ts().decrby("time-serie-2", 10, chunk_size=128) + info = client.ts().info("time-serie-2") + assert 128 == info.chunk_size + + @pytest.mark.redismod + def testCreateAndDeleteRule(self, client): + # test rule creation + time = 100 + client.ts().create(1) + client.ts().create(2) + client.ts().createrule(1, 2, "avg", 100) + for i in range(50): + client.ts().add(1, time + i * 2, 1) + client.ts().add(1, time + i * 2 + 1, 2) + client.ts().add(1, time * 2, 1.5) + assert round(client.ts().get(2)[1], 5) == 1.5 + info = client.ts().info(1) + assert info.rules[0][1] == 100 + + # test rule deletion + client.ts().deleterule(1, 2) + info = client.ts().info(1) + assert not info.rules + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", "timeseries") + def testDelRange(self, client): + try: + client.ts().delete("test", 0, 100) + except Exception as e: + assert e.__str__() != "" + + for i in range(100): + client.ts().add(1, i, i % 7) + assert 22 == client.ts().delete(1, 0, 21) + assert [] == client.ts().range(1, 0, 21) + assert [(22, 1.0)] == 
client.ts().range(1, 22, 22) + + @pytest.mark.redismod + def testRange(self, client): + for i in range(100): + client.ts().add(1, i, i % 7) + assert 100 == len(client.ts().range(1, 0, 200)) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + assert 200 == len(client.ts().range(1, 0, 500)) + # last sample isn't returned + assert 20 == len( + client.ts().range( + 1, + 0, + 500, + aggregation_type="avg", + bucket_size_msec=10 + ) + ) + assert 10 == len(client.ts().range(1, 0, 500, count=10)) + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", "timeseries") + def testRangeAdvanced(self, client): + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(1, i + 200, i % 7) + + assert 2 == len( + client.ts().range( + 1, + 0, + 500, + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + ) + assert [(0, 10.0), (10, 1.0)] == client.ts().range( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ) + assert [(-5, 5.0), (5, 6.0)] == client.ts().range( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 + ) + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", "timeseries") + def testRevRange(self, client): + for i in range(100): + client.ts().add(1, i, i % 7) + assert 100 == len(client.ts().range(1, 0, 200)) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + assert 200 == len(client.ts().range(1, 0, 500)) + # first sample isn't returned + assert 20 == len( + client.ts().revrange( + 1, + 0, + 500, + aggregation_type="avg", + bucket_size_msec=10 + ) + ) + assert 10 == len(client.ts().revrange(1, 0, 500, count=10)) + assert 2 == len( + client.ts().revrange( + 1, + 0, + 500, + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + ) + assert [(10, 1.0), (0, 10.0)] == client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ) + assert [(1, 10.0), (-9, 1.0)] == client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ) + + @pytest.mark.redismod + def testMultiRange(self, client): + client.ts().create(1, labels={"Test": "This", "team": "ny"}) + client.ts().create( + 2, + labels={"Test": "This", "Taste": "That", "team": "sf"} + ) + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(2, i, i % 11) + + res = client.ts().mrange(0, 200, filters=["Test=This"]) + assert 2 == len(res) + assert 100 == len(res[0]["1"][1]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( 0, 500, + filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) - ) - assert 10 == len(client.ts().range(1, 0, 500, count=10)) - + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], + with_labels=True) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", "timeseries") + def testMultiRangeAdvanced(self, client): + client.ts().create(1, labels={"Test": "This", "team": "ny"}) + client.ts().create( + 2, + labels={"Test": "This", "Taste": "That", "team": "sf"} + ) + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(2, i, i % 11) -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def testRangeAdvanced(client): - for i in 
range(100): - client.ts().add(1, i, i % 7) - client.ts().add(1, i + 200, i % 7) + # test with selected labels + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] - assert 2 == len( - client.ts().range( - 1, + # test with filterby + res = client.ts().mrange( 0, - 500, + 200, + filters=["Test=This"], filter_by_ts=[i for i in range(10, 20)], filter_by_min_value=1, filter_by_max_value=2, ) - ) - assert [(0, 10.0), (10, 1.0)] == client.ts().range( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" - ) - assert [(-5, 5.0), (5, 6.0)] == client.ts().range( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 - ) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def testRevRange(client): - for i in range(100): - client.ts().add(1, i, i % 7) - assert 100 == len(client.ts().range(1, 0, 200)) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - assert 200 == len(client.ts().range(1, 0, 500)) - # first sample isn't returned - assert 20 == len( - client.ts().revrange( - 1, + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + + # test groupby + res = client.ts().mrange( + 0, + 3, + filters=["Test=This"], + groupby="Test", + reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][ + 1] + res = client.ts().mrange( + 0, + 3, + filters=["Test=This"], + groupby="Test", + reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][ + 1] + res = client.ts().mrange( + 0, + 3, + filters=["Test=This"], + groupby="team", + reduce="min") + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(-5, 5.0), (5, 6.0)] == res[0]["1"][1] + + @pytest.mark.redismod + @skip_ifmodversion_lt("99.99.99", "timeseries") + def testMultiReverseRange(self, client): + client.ts().create(1, labels={"Test": "This", "team": "ny"}) + client.ts().create( + 2, + labels={"Test": "This", "Taste": "That", "team": "sf"} + ) + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(2, i, i % 11) + + res = client.ts().mrange(0, 200, filters=["Test=This"]) + assert 2 == len(res) + assert 100 == len(res[0]["1"][1]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrevrange( 0, 500, + filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) - ) - assert 10 == len(client.ts().revrange(1, 0, 500, count=10)) - assert 2 == len( - client.ts().revrange( - 1, + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] + + # test withlabels + res = client.ts().mrevrange( 0, - 500, + 200, + filters=["Test=This"], + with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + + # test with selected labels + res = client.ts().mrevrange( + 0, + 200, + filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": 
"sf"} == res[1]["2"][0] + + # test filterby + res = client.ts().mrevrange( + 0, + 200, + filters=["Test=This"], filter_by_ts=[i for i in range(10, 20)], filter_by_min_value=1, filter_by_max_value=2, ) - ) - assert [(10, 1.0), (0, 10.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" - ) - assert [(1, 10.0), (-9, 1.0)] == client.ts().revrange( - 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 - ) - - -@pytest.mark.redismod -def testMultiRange(client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) - - res = client.ts().mrange(0, 200, filters=["Test=This"]) - assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) - - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) - - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrange( - 0, - 500, - filters=["Test=This"], - aggregation_type="avg", - bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - - # test withlabels - assert {} == res[0]["1"][0] - res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def testMultiRangeAdvanced(client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) - - # test with selected labels - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test with filterby - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - - # test groupby - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="Test", - reduce="sum" - ) - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="Test", - reduce="max" - ) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="team", - reduce="min") - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - - # test align - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(-5, 5.0), (5, 6.0)] == res[0]["1"][1] - - -@pytest.mark.redismod -@skip_ifmodversion_lt("99.99.99", "timeseries") -def testMultiReverseRange(client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, 
i % 11) - - res = client.ts().mrange(0, 200, filters=["Test=This"]) - assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) - - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) - - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrevrange( - 0, - 500, - filters=["Test=This"], - aggregation_type="avg", - bucket_size_msec=10 - ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] - - # test withlabels - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - with_labels=True - ) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - - # test with selected labels - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test filterby - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - filter_by_ts=[i for i in range(10, 20)], - filter_by_min_value=1, - filter_by_max_value=2, - ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] - - # test groupby - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] - - # test align - res = client.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] - res = client.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=1, - ) - assert [(1, 10.0), (-9, 1.0)] == res[0]["1"][1] - - -@pytest.mark.redismod -def testGet(client): - name = "test" - client.ts().create(name) - assert client.ts().get(name) is None - client.ts().add(name, 2, 3) - assert 2 == client.ts().get(name)[0] - client.ts().add(name, 3, 4) - assert 4 == client.ts().get(name)[1] - - -@pytest.mark.redismod -def testMGet(client): - client.ts().create(1, labels={"Test": "This"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That"}) - act_res = client.ts().mget(["Test=This"]) - exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res - client.ts().add(1, "*", 15) - client.ts().add(2, "*", 25) - res = client.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] - res = client.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] - - # test with_labels - assert {} == res[0]["2"][0] - res = client.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] - - -@pytest.mark.redismod -def testInfo(client): - client.ts().create( - 1, - retention_msecs=5, - labels={"currentLabel": "currentData"} - ) - info = client.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" - - -@pytest.mark.redismod -@skip_ifmodversion_lt("1.4.0", "timeseries") -def testInfoDuplicatePolicy(client): - client.ts().create( - 1, - retention_msecs=5, - labels={"currentLabel": 
"currentData"} - ) - info = client.ts().info(1) - assert info.duplicate_policy is None - - client.ts().create("time-serie-2", duplicate_policy="min") - info = client.ts().info("time-serie-2") - assert "min" == info.duplicate_policy - - -@pytest.mark.redismod -def testQueryIndex(client): - client.ts().create(1, labels={"Test": "This"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That"}) - assert 2 == len(client.ts().queryindex(["Test=This"])) - assert 1 == len(client.ts().queryindex(["Taste=That"])) - assert [2] == client.ts().queryindex(["Taste=That"]) - - -# -# @pytest.mark.redismod -# @pytest.mark.pipeline -# def testPipeline(client): -# pipeline = client.ts().pipeline() -# pipeline.create("with_pipeline") -# for i in range(100): -# pipeline.add("with_pipeline", i, 1.1 * i) -# pipeline.execute() - -# info = client.ts().info("with_pipeline") -# assert info.lastTimeStamp == 99 -# assert info.total_samples == 100 -# assert client.ts().get("with_pipeline")[1] == 99 * 1.1 - - -@pytest.mark.redismod -def testUncompressed(client): - client.ts().create("compressed") - client.ts().create("uncompressed", uncompressed=True) - compressed_info = client.ts().info("compressed") - uncompressed_info = client.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + + # test groupby + res = client.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][ + 1] + res = client.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][ + 1] + res = client.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + + # test align + res = client.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + res = client.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [(1, 10.0), (-9, 1.0)] == res[0]["1"][1] + + @pytest.mark.redismod + def testGet(self, client): + name = "test" + client.ts().create(name) + assert client.ts().get(name) is None + client.ts().add(name, 2, 3) + assert 2 == client.ts().get(name)[0] + client.ts().add(name, 3, 4) + assert 4 == client.ts().get(name)[1] + + @pytest.mark.redismod + def testMGet(self, client): + client.ts().create(1, labels={"Test": "This"}) + client.ts().create(2, labels={"Test": "This", "Taste": "That"}) + act_res = client.ts().mget(["Test=This"]) + exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] + assert act_res == exp_res + client.ts().add(1, "*", 15) + client.ts().add(2, "*", 25) + res = client.ts().mget(["Test=This"]) + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + res = client.ts().mget(["Taste=That"]) + assert 25 == res[0]["2"][2] + + # test with_labels + assert {} == res[0]["2"][0] + res = client.ts().mget(["Taste=That"], with_labels=True) + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + + @pytest.mark.redismod + def testInfo(self, client): + client.ts().create( + 1, + retention_msecs=5, + labels={"currentLabel": "currentData"} + ) + info = client.ts().info(1) + 
assert 5 == info.retention_msecs + assert info.labels["currentLabel"] == "currentData" + + @pytest.mark.redismod + @skip_ifmodversion_lt("1.4.0", "timeseries") + def testInfoDuplicatePolicy(self, client): + client.ts().create( + 1, + retention_msecs=5, + labels={"currentLabel": "currentData"} + ) + info = client.ts().info(1) + assert info.duplicate_policy is None + + client.ts().create("time-serie-2", duplicate_policy="min") + info = client.ts().info("time-serie-2") + assert "min" == info.duplicate_policy + + @pytest.mark.redismod + def testQueryIndex(self, client): + client.ts().create(1, labels={"Test": "This"}) + client.ts().create(2, labels={"Test": "This", "Taste": "That"}) + assert 2 == len(client.ts().queryindex(["Test=This"])) + assert 1 == len(client.ts().queryindex(["Taste=That"])) + assert [2] == client.ts().queryindex(["Taste=That"]) + + # + # @pytest.mark.redismod + # @pytest.mark.pipeline + # def testPipeline(client): + # pipeline = client.ts().pipeline() + # pipeline.create("with_pipeline") + # for i in range(100): + # pipeline.add("with_pipeline", i, 1.1 * i) + # pipeline.execute() + + # info = client.ts().info("with_pipeline") + # assert info.lastTimeStamp == 99 + # assert info.total_samples == 100 + # assert client.ts().get("with_pipeline")[1] == 99 * 1.1 + + @pytest.mark.redismod + def testUncompressed(self, client): + client.ts().create("compressed") + client.ts().create("uncompressed", uncompressed=True) + compressed_info = client.ts().info("compressed") + uncompressed_info = client.ts().info("uncompressed") + assert compressed_info.memory_usage != uncompressed_info.memory_usage From 45fb0302a032e01df52a638a069c4219bccdb567 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 28 Oct 2021 12:45:35 +0300 Subject: [PATCH 03/22] starting to clean the docs (#1657) --- CONTRIBUTING.md | 22 ++-- README.md | 328 +++++++++++++++++++++++------------------------- 2 files changed, 171 insertions(+), 179 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 31170f3718..af067e7fdf 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,7 +10,7 @@ community contributions! You may already know what you want to contribute \-- a fix for a bug you encountered, or a new feature your team wants to use. -If you don\'t know what to contribute, keep an open mind! Improving +If you don't know what to contribute, keep an open mind! Improving documentation, bug triaging, and writing tutorials are all examples of helpful contributions that mean less work for you. @@ -28,7 +28,7 @@ tutorials: ## Getting Started -Here\'s how to get started with your code contribution: +Here's how to get started with your code contribution: 1. Create your own fork of redis-py 2. Do the changes in your fork @@ -134,19 +134,19 @@ Please try at least versions of Docker. ### Security Vulnerabilities **NOTE**: If you find a security vulnerability, do NOT open an issue. -Email Andy McCurdy () instead. +Email [Redis Open Source ()](mailto:oss@redis.com) instead. In order to determine whether you are dealing with a security issue, ask yourself these two questions: -- Can I access something that\'s not mine, or something I shouldn\'t +- Can I access something that's not mine, or something I shouldn't have access to? - Can I disable something for other people? -If the answer to either of those two questions are \"yes\", then you\'re +If the answer to either of those two questions are *yes*, then you're probably dealing with a security issue. 
Note that even if you answer -\"no\" to both questions, you may still be dealing with a security -issue, so if you\'re unsure, just email Andy at . +*no* to both questions, you may still be dealing with a security +issue, so if you're unsure, just email [us](mailto:oss@redis.com). ### Everything Else @@ -160,17 +160,17 @@ When filing an issue, make sure to answer these five questions: ## How to Suggest a Feature or Enhancement -If you\'d like to contribute a new feature, make sure you check our +If you'd like to contribute a new feature, make sure you check our issue list to see if someone has already proposed it. Work may already -be under way on the feature you want \-- or we may have rejected a +be under way on the feature you want -- or we may have rejected a feature like it already. -If you don\'t see anything, open a new issue that describes the feature +If you don't see anything, open a new issue that describes the feature you would like and how it should work. ## Code Review Process The core team looks at Pull Requests on a regular basis. We will give feedback as as soon as possible. After feedback, we expect a response -within two weeks. After that time, we may close your PR if it isn\'t +within two weeks. After that time, we may close your PR if it isn't showing any activity. diff --git a/README.md b/README.md index 5cd9b53570..8b42b65d26 100644 --- a/README.md +++ b/README.md @@ -4,21 +4,19 @@ The Python interface to the Redis key-value store. [![CI](https://github.com/redis/redis-py/workflows/CI/badge.svg?branch=master)](https://github.com/redis/redis-py/actions?query=workflow%3ACI+branch%3Amaster) [![docs](https://readthedocs.org/projects/redis-py/badge/?version=stable&style=flat)](https://redis-py.readthedocs.io/en/stable/) +[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.txt) [![pypi](https://badge.fury.io/py/redis.svg)](https://pypi.org/project/redis/) [![codecov](https://codecov.io/gh/redis/redis-py/branch/master/graph/badge.svg?token=yenl5fzxxr)](https://codecov.io/gh/redis/redis-py) [![Total alerts](https://img.shields.io/lgtm/alerts/g/redis/redis-py.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/redis/redis-py/alerts/) +[Installation](##installation) | [Contributing](##contributing) | [Getting Started](##getting-started) | [Connecting To Redis](##connecting-to-redis) -## Python 2 Compatibility Note +--------------------------------------------- -redis-py 3.5.x will be the last version of redis-py that supports Python -2. The 3.5.x line will continue to get bug fixes and security patches -that support Python 2 until August 1, 2020. redis-py 4.0 will be the -next major version and will require Python 3.5+. ## Installation -redis-py requires a running Redis server. See [Redis\'s +redis-py requires a running Redis server. See [Redis's quickstart](https://redis.io/topics/quickstart) for installation instructions. @@ -45,12 +43,14 @@ $ python setup.py install ## Contributing -Want to contribute a feature, bug report, or report an issue? Check out +Want to contribute a feature, bug fix, or report an issue? Check out our [guide to -contributing](https://github.com/redis/redis-py/blob/master/CONTRIBUTING.rst). +contributing](https://github.com/redis/redis-py/blob/master/CONTRIBUTING.md). ## Getting Started +redis-py supports Python 3.6+. 
+ ``` pycon >>> import redis >>> r = redis.Redis(host='localhost', port=6379, db=0) @@ -61,118 +61,26 @@ b'bar' ``` By default, all responses are returned as bytes in Python -3 and str in Python 2. The user is responsible for -decoding to Python 3 strings or Python 2 unicode objects. +3. If **all** string responses from a client should be decoded, the user -can specify decode_responses=True to -Redis.\_\_init\_\_. In this case, any Redis command that +can specify *decode_responses=True* in +```Redis.__init__```. In this case, any Redis command that returns a string type will be decoded with the encoding specified. -The default encoding is \"utf-8\", but this can be customized with the -encoding argument to the redis.Redis class. +The default encoding is utf-8, but this can be customized by specifiying the +encoding argument for the redis.Redis class. The encoding will be used to automatically encode any -strings passed to commands, such as key names and values. When -decode_responses=True, string data returned from commands -will be decoded with the same encoding. - -## Upgrading from redis-py 2.X to 3.0 - -redis-py 3.0 introduces many new features but required a number of -backwards incompatible changes to be made in the process. This section -attempts to provide an upgrade path for users migrating from 2.X to 3.0. - -### Python Version Support - -redis-py supports Python 3.5+. - -### Client Classes: Redis and StrictRedis - -redis-py 3.0 drops support for the legacy \"Redis\" client class. -\"StrictRedis\" has been renamed to \"Redis\" and an alias named -\"StrictRedis\" is provided so that users previously using -\"StrictRedis\" can continue to run unchanged. - -The 2.X \"Redis\" class provided alternative implementations of a few -commands. This confused users (rightfully so) and caused a number of -support issues. To make things easier going forward, it was decided to -drop support for these alternate implementations and instead focus on a -single client class. - -2.X users that are already using StrictRedis don\'t have to change the -class name. StrictRedis will continue to work for the foreseeable -future. - -2.X users that are using the Redis class will have to make changes if -they use any of the following commands: - -- SETEX: The argument order has changed. The new order is (name, time, - value). -- LREM: The argument order has changed. The new order is (name, num, - value). -- TTL and PTTL: The return value is now always an int and matches the - official Redis command (>0 indicates the timeout, -1 indicates that - the key exists but that it has no expire time set, -2 indicates that - the key does not exist) - -### SSL Connections - -redis-py 3.0 changes the default value of the -ssl_cert_reqs option from None to -\'required\'. See [Issue -1016](https://github.com/redis/redis-py/issues/1016). This change -enforces hostname validation when accepting a cert from a remote SSL -terminator. If the terminator doesn\'t properly set the hostname on the -cert this will cause redis-py 3.0 to raise a ConnectionError. - -This check can be disabled by setting ssl_cert_reqs to -None. Note that doing so removes the security check. Do so -at your own risk. 
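As a minimal sketch of the *decode_responses* behavior described earlier
(assuming a local server is running; the key and value are illustrative):

``` pycon
>>> import redis
>>> r = redis.Redis(host='localhost', port=6379, db=0,
...                 decode_responses=True)
>>> r.set('foo', 'bar')
True
>>> r.get('foo')
'bar'
```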
- -Example with hostname verification using a local certificate bundle -(linux): - -``` pycon ->>> import redis ->>> r = redis.Redis(host='xxxxxx.cache.amazonaws.com', port=6379, db=0, - ssl=True, - ssl_ca_certs='/etc/ssl/certs/ca-certificates.crt') ->>> r.set('foo', 'bar') -True ->>> r.get('foo') -b'bar' -``` - -Example with hostname verification using -[certifi](https://pypi.org/project/certifi/): - -``` pycon ->>> import redis, certifi ->>> r = redis.Redis(host='xxxxxx.cache.amazonaws.com', port=6379, db=0, - ssl=True, ssl_ca_certs=certifi.where()) ->>> r.set('foo', 'bar') -True ->>> r.get('foo') -b'bar' -``` +strings passed to commands, such as key names and values. -Example turning off hostname verification (not recommended): -``` pycon ->>> import redis ->>> r = redis.Redis(host='xxxxxx.cache.amazonaws.com', port=6379, db=0, - ssl=True, ssl_cert_reqs=None) ->>> r.set('foo', 'bar') -True ->>> r.get('foo') -b'bar' -``` +-------------------- ### MSET, MSETNX and ZADD These commands all accept a mapping of key/value pairs. In redis-py 2.X -this mapping could be specified as `*args` or as `**kwargs`. Both of +this mapping could be specified as **args* or as `**kwargs`. Both of these styles caused issues when Redis introduced optional flags to ZADD. Relying on `*args` caused issues with the optional argument order, especially in Python 2.7. Relying on `**kwargs` caused potential @@ -229,8 +137,8 @@ supports the Lua-based lock. In doing so, LuaLock has been renamed to Lock. This also means that redis-py Lock objects require Redis server 2.6 or greater. -2.X users that were explicitly referring to \"LuaLock\" will have to now -refer to \"Lock\" instead. +2.X users that were explicitly referring to *LuaLock* will have to now +refer to *Lock* instead. ### Locks as Context Managers @@ -240,7 +148,7 @@ This is more of a bug fix than a backwards incompatible change. However, given an error is now raised where none was before, this might alarm some users. -2.X users should make sure they\'re wrapping their lock code in a +2.X users should make sure they're wrapping their lock code in a try/catch like this: ``` python @@ -259,8 +167,8 @@ to adhere to the official command syntax. There are a few exceptions: - **SELECT**: Not implemented. See the explanation in the Thread Safety section below. -- **DEL**: \'del\' is a reserved keyword in the Python syntax. - Therefore redis-py uses \'delete\' instead. +- **DEL**: *del* is a reserved keyword in the Python syntax. + Therefore redis-py uses *delete* instead. - **MULTI/EXEC**: These are implemented as part of the Pipeline class. The pipeline is wrapped with the MULTI and EXEC statements by default when it is executed, which can be disabled by specifying @@ -273,14 +181,44 @@ to adhere to the official command syntax. There are a few exceptions: PUBLISH from the Redis client (see [this comment on issue #151](https://github.com/redis/redis-py/issues/151#issuecomment-1545015) for details). -- **SCAN/SSCAN/HSCAN/ZSCAN**: The \*SCAN commands are implemented as +- **SCAN/SSCAN/HSCAN/ZSCAN**: The *SCAN commands are implemented as they exist in the Redis documentation. In addition, each command has an equivalent iterator method. These are purely for convenience so - the user doesn\'t have to keep track of the cursor while iterating. + the user doesn't have to keep track of the cursor while iterating. Use the scan_iter/sscan_iter/hscan_iter/zscan_iter methods for this behavior. 
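As a minimal sketch of the iterator form described above (assuming a few
matching keys already exist; the key names are illustrative):

``` pycon
>>> for key in r.scan_iter(match='user:*'):
...     print(key)
b'user:1'
b'user:2'
```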
-## More Detail +## Connecting to Redis + +### Client Classes: Redis and StrictRedis + +redis-py 3.0 drops support for the legacy *Redis* client class. +*StrictRedis* has been renamed to *Redis* and an alias named +*StrictRedis* is provided so that users previously using +*StrictRedis* can continue to run unchanged. + +The 2.X *Redis* class provided alternative implementations of a few +commands. This confused users (rightfully so) and caused a number of +support issues. To make things easier going forward, it was decided to +drop support for these alternate implementations and instead focus on a +single client class. + +2.X users that are already using StrictRedis don\'t have to change the +class name. StrictRedis will continue to work for the foreseeable +future. + +2.X users that are using the Redis class will have to make changes if +they use any of the following commands: + +- SETEX: The argument order has changed. The new order is (name, time, + value). +- LREM: The argument order has changed. The new order is (name, num, + value). +- TTL and PTTL: The return value is now always an int and matches the + official Redis command (>0 indicates the timeout, -1 indicates that + the key exists but that it has no expire time set, -2 indicates that + the key does not exist) + ### Connection Pools @@ -355,7 +293,7 @@ option to a value less than 30. This option also works on any PubSub connection that is created from a client with `health_check_interval` enabled. PubSub users need to ensure -that `get_message()` or `listen()` are called more frequently than +that *get_message()* or `listen()` are called more frequently than `health_check_interval` seconds. It is assumed that most workloads already do this. @@ -363,6 +301,108 @@ If your PubSub use case doesn\'t call `get_message()` or `listen()` frequently, you should call `pubsub.check_health()` explicitly on a regularly basis. +### SSL Connections + +redis-py 3.0 changes the default value of the +ssl_cert_reqs option from None to +\'required\'. See [Issue +1016](https://github.com/redis/redis-py/issues/1016). This change +enforces hostname validation when accepting a cert from a remote SSL +terminator. If the terminator doesn\'t properly set the hostname on the +cert this will cause redis-py 3.0 to raise a ConnectionError. + +This check can be disabled by setting ssl_cert_reqs to +None. Note that doing so removes the security check. Do so +at your own risk. + +Example with hostname verification using a local certificate bundle +(linux): + +``` pycon +>>> import redis +>>> r = redis.Redis(host='xxxxxx.cache.amazonaws.com', port=6379, db=0, + ssl=True, + ssl_ca_certs='/etc/ssl/certs/ca-certificates.crt') +>>> r.set('foo', 'bar') +True +>>> r.get('foo') +b'bar' +``` + +Example with hostname verification using +[certifi](https://pypi.org/project/certifi/): + +``` pycon +>>> import redis, certifi +>>> r = redis.Redis(host='xxxxxx.cache.amazonaws.com', port=6379, db=0, + ssl=True, ssl_ca_certs=certifi.where()) +>>> r.set('foo', 'bar') +True +>>> r.get('foo') +b'bar' +``` + +Example turning off hostname verification (not recommended): + +``` pycon +>>> import redis +>>> r = redis.Redis(host='xxxxxx.cache.amazonaws.com', port=6379, db=0, + ssl=True, ssl_cert_reqs=None) +>>> r.set('foo', 'bar') +True +>>> r.get('foo') +b'bar' +``` + +### Sentinel support + +redis-py can be used together with [Redis +Sentinel](https://redis.io/topics/sentinel) to discover Redis nodes. 
You +need to have at least one Sentinel daemon running in order to use +redis-py's Sentinel support. + +Connecting redis-py to the Sentinel instance(s) is easy. You can use a +Sentinel connection to discover the master and slaves network addresses: + +``` pycon +>>> from redis.sentinel import Sentinel +>>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) +>>> sentinel.discover_master('mymaster') +('127.0.0.1', 6379) +>>> sentinel.discover_slaves('mymaster') +[('127.0.0.1', 6380)] +``` + +You can also create Redis client connections from a Sentinel instance. +You can connect to either the master (for write operations) or a slave +(for read-only operations). + +``` pycon +>>> master = sentinel.master_for('mymaster', socket_timeout=0.1) +>>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) +>>> master.set('foo', 'bar') +>>> slave.get('foo') +b'bar' +``` + +The master and slave objects are normal Redis instances with their +connection pool bound to the Sentinel instance. When a Sentinel backed +client attempts to establish a connection, it first queries the Sentinel +servers to determine an appropriate host to connect to. If no server is +found, a MasterNotFoundError or SlaveNotFoundError is raised. Both +exceptions are subclasses of ConnectionError. + +When trying to connect to a slave client, the Sentinel connection pool +will iterate over the list of slaves until it finds one that can be +connected to. If no slaves can be connected to, a connection will be +established with the master. + +See [Guidelines for Redis clients with support for Redis +Sentinel](https://redis.io/topics/sentinel-clients) to learn more about +Redis Sentinel. + +-------------------------- + ### Parsers Parser classes provide a way to control how responses from the Redis @@ -875,52 +915,6 @@ just prior to pipeline execution. [True, 25] ``` -### Sentinel support - -redis-py can be used together with [Redis -Sentinel](https://redis.io/topics/sentinel) to discover Redis nodes. You -need to have at least one Sentinel daemon running in order to use -redis-py\'s Sentinel support. - -Connecting redis-py to the Sentinel instance(s) is easy. You can use a -Sentinel connection to discover the master and slaves network addresses: - -``` pycon ->>> from redis.sentinel import Sentinel ->>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) ->>> sentinel.discover_master('mymaster') -('127.0.0.1', 6379) ->>> sentinel.discover_slaves('mymaster') -[('127.0.0.1', 6380)] -``` - -You can also create Redis client connections from a Sentinel instance. -You can connect to either the master (for write operations) or a slave -(for read-only operations). - -``` pycon ->>> master = sentinel.master_for('mymaster', socket_timeout=0.1) ->>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) ->>> master.set('foo', 'bar') ->>> slave.get('foo') -b'bar' -``` - -The master and slave objects are normal Redis instances with their -connection pool bound to the Sentinel instance. When a Sentinel backed -client attempts to establish a connection, it first queries the Sentinel -servers to determine an appropriate host to connect to. If no server is -found, a MasterNotFoundError or SlaveNotFoundError is raised. Both -exceptions are subclasses of ConnectionError. - -When trying to connect to a slave client, the Sentinel connection pool -will iterate over the list of slaves until it finds one that can be -connected to. If no slaves can be connected to, a connection will be -established with the master. 
- -See [Guidelines for Redis clients with support for Redis -Sentinel](https://redis.io/topics/sentinel-clients) to learn more about -Redis Sentinel. ### Scan Iterators @@ -1167,18 +1161,16 @@ to learn more about Redis Cluster. ### Author -redis-py is developed and maintained by Andy McCurdy -(). It can be found here: - +redis-py is developed and maintained by [Redis Inc](https://redis.com). It can be found [here]( +https://github.com/redis/redis-py), or downloaded from [pypi](https://pypi.org/project/redis/). Special thanks to: +- Andy McCurdy () the original author of redis-py. - Ludovico Magnocavallo, author of the original Python Redis client, from which some of the socket code is still used. - Alexander Solovyov for ideas on the generic response callback system. - Paul Hubbard for initial packaging support. -### Sponsored by - [![Redis](./docs/logo-redis.png)](https://www.redis.com) From f38deece8bab3495a81a7f9fca70b7a5c806d361 Mon Sep 17 00:00:00 2001 From: Chayim Date: Thu, 28 Oct 2021 12:46:04 +0300 Subject: [PATCH 04/22] Adding vulture for static analysis (#1655) * Adding vulture for static analysis Removing dead code found previously by vulture in local runs. --- dev_requirements.txt | 3 ++- redis/connection.py | 1 - redis/features.py | 5 ----- tasks.py | 2 +- tox.ini | 29 +++++++++++------------------ whitelist.py | 12 ++++++++++++ 6 files changed, 26 insertions(+), 26 deletions(-) delete mode 100644 redis/features.py create mode 100644 whitelist.py diff --git a/dev_requirements.txt b/dev_requirements.txt index d3f91fef3d..aa9d8f9eee 100644 --- a/dev_requirements.txt +++ b/dev_requirements.txt @@ -3,4 +3,5 @@ pytest==6.2.5 tox==3.24.4 tox-docker==3.1.0 invoke==1.6.0 -pytest-cov>=3.0.0 \ No newline at end of file +pytest-cov>=3.0.0 +vulture>=2.3.0 diff --git a/redis/connection.py b/redis/connection.py index e1ad6ea7f2..f2becbeba7 100755 --- a/redis/connection.py +++ b/redis/connection.py @@ -757,7 +757,6 @@ def can_read(self, timeout=0): sock = self._sock if not sock: self.connect() - sock = self._sock return self._parser.can_read(timeout) def read_response(self): diff --git a/redis/features.py b/redis/features.py deleted file mode 100644 index a96bac7c77..0000000000 --- a/redis/features.py +++ /dev/null @@ -1,5 +0,0 @@ -try: - import hiredis # noqa - HIREDIS_AVAILABLE = True -except ImportError: - HIREDIS_AVAILABLE = False diff --git a/tasks.py b/tasks.py index 44b652908d..2631c702d4 100644 --- a/tasks.py +++ b/tasks.py @@ -23,7 +23,7 @@ def devenv(c): @task def linters(c): """Run code linters""" - run("flake8") + run("tox -e linters") @task diff --git a/tox.ini b/tox.ini index e46a0c47f5..a511cc6059 100644 --- a/tox.ini +++ b/tox.ini @@ -6,7 +6,7 @@ markers = [tox] minversion = 3.2.0 requires = tox-docker -envlist = {py35,py36,py37,py38,py39,pypy3}-{plain,hiredis}, flake8 +envlist = {py35,py36,py37,py38,py39,pypy3}-{plain,hiredis},linters [docker:master] name = master @@ -119,9 +119,12 @@ docker = lots-of-pythons commands = /usr/bin/echo -[testenv:flake8] +[testenv:linters] deps_files = dev_requirements.txt -commands = flake8 +docker= +commands = + flake8 + vulture redis whitelist.py --min-confidence 80 skipsdist = true skip_install = true @@ -131,18 +134,8 @@ basepython = pypy3 [testenv:pypy3-hiredis] basepython = pypy3 -#[testenv:codecov] -#deps = codecov -#commands = codecov -#passenv = -# REDIS_* -# CI -# CI_* -# CODECOV_* -# SHIPPABLE -# GITHUB_* -# VCS_* -# -#[testenv:covreport] -#deps = coverage -#commands = coverage report +[flake8] +exclude = + .venv, + .tox, + 
whitelist.py diff --git a/whitelist.py b/whitelist.py new file mode 100644 index 0000000000..891ccd6022 --- /dev/null +++ b/whitelist.py @@ -0,0 +1,12 @@ +exc_type # unused variable (/data/repos/redis/redis-py/redis/client.py:1045) +exc_value # unused variable (/data/repos/redis/redis-py/redis/client.py:1045) +traceback # unused variable (/data/repos/redis/redis-py/redis/client.py:1045) +exc_type # unused variable (/data/repos/redis/redis-py/redis/client.py:1211) +exc_value # unused variable (/data/repos/redis/redis-py/redis/client.py:1211) +traceback # unused variable (/data/repos/redis/redis-py/redis/client.py:1211) +exc_type # unused variable (/data/repos/redis/redis-py/redis/client.py:1589) +exc_value # unused variable (/data/repos/redis/redis-py/redis/client.py:1589) +traceback # unused variable (/data/repos/redis/redis-py/redis/client.py:1589) +exc_type # unused variable (/data/repos/redis/redis-py/redis/lock.py:156) +exc_value # unused variable (/data/repos/redis/redis-py/redis/lock.py:156) +traceback # unused variable (/data/repos/redis/redis-py/redis/lock.py:156) From ac378a90dcb9f2c1470f77c426f7731cc1f32580 Mon Sep 17 00:00:00 2001 From: Anas Date: Tue, 2 Nov 2021 10:13:04 +0200 Subject: [PATCH 05/22] Added boolean parsing to PEXPIRE and PEXPIREAT (#1665) --- redis/client.py | 3 ++- tests/test_commands.py | 24 ++++++++++++------------ 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/redis/client.py b/redis/client.py index 93979569b1..3768c2e5e1 100755 --- a/redis/client.py +++ b/redis/client.py @@ -639,7 +639,8 @@ class Redis(RedisModuleCommands, CoreCommands, object): """ RESPONSE_CALLBACKS = { **string_keys_to_dict( - 'AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET LMOVE BLMOVE MOVE ' + 'AUTH COPY EXPIRE EXPIREAT PEXPIRE PEXPIREAT ' + 'HEXISTS HMSET LMOVE BLMOVE MOVE ' 'MSETNX PERSIST PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', bool ), diff --git a/tests/test_commands.py b/tests/test_commands.py index 998bc0f0e6..27e548f4f9 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -843,9 +843,9 @@ def test_exists_contains(self, r): assert 'a' in r def test_expire(self, r): - assert not r.expire('a', 10) + assert r.expire('a', 10) is False r['a'] = 'foo' - assert r.expire('a', 10) + assert r.expire('a', 10) is True assert 0 < r.ttl('a') <= 10 assert r.persist('a') assert r.ttl('a') == -1 @@ -853,18 +853,18 @@ def test_expire(self, r): def test_expireat_datetime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) r['a'] = 'foo' - assert r.expireat('a', expire_at) + assert r.expireat('a', expire_at) is True assert 0 < r.ttl('a') <= 61 def test_expireat_no_key(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - assert not r.expireat('a', expire_at) + assert r.expireat('a', expire_at) is False def test_expireat_unixtime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) r['a'] = 'foo' expire_at_seconds = int(time.mktime(expire_at.timetuple())) - assert r.expireat('a', expire_at_seconds) + assert r.expireat('a', expire_at_seconds) is True assert 0 < r.ttl('a') <= 61 def test_get_and_set(self, r): @@ -1007,9 +1007,9 @@ def test_msetnx(self, r): @skip_if_server_version_lt('2.6.0') def test_pexpire(self, r): - assert not r.pexpire('a', 60000) + assert r.pexpire('a', 60000) is False r['a'] = 'foo' - assert r.pexpire('a', 60000) + assert r.pexpire('a', 60000) is True assert 0 < r.pttl('a') <= 60000 assert r.persist('a') assert r.pttl('a') == -1 @@ -1018,20 +1018,20 @@ def test_pexpire(self, r): def 
test_pexpireat_datetime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) r['a'] = 'foo' - assert r.pexpireat('a', expire_at) + assert r.pexpireat('a', expire_at) is True assert 0 < r.pttl('a') <= 61000 @skip_if_server_version_lt('2.6.0') def test_pexpireat_no_key(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) - assert not r.pexpireat('a', expire_at) + assert r.pexpireat('a', expire_at) is False @skip_if_server_version_lt('2.6.0') def test_pexpireat_unixtime(self, r): expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) r['a'] = 'foo' expire_at_seconds = int(time.mktime(expire_at.timetuple())) * 1000 - assert r.pexpireat('a', expire_at_seconds) + assert r.pexpireat('a', expire_at_seconds) is True assert 0 < r.pttl('a') <= 61000 @skip_if_server_version_lt('2.6.0') @@ -1049,9 +1049,9 @@ def test_psetex_timedelta(self, r): @skip_if_server_version_lt('2.6.0') def test_pttl(self, r): - assert not r.pexpire('a', 10000) + assert r.pexpire('a', 10000) is False r['a'] = '1' - assert r.pexpire('a', 10000) + assert r.pexpire('a', 10000) is True assert 0 < r.pttl('a') <= 10000 assert r.persist('a') assert r.pttl('a') == -1 From 62a895630ddba1aa560f6b6bbdfe9ef960c96415 Mon Sep 17 00:00:00 2001 From: Chayim Date: Tue, 2 Nov 2021 11:55:07 +0200 Subject: [PATCH 06/22] Improved JSON accuracy (#1666) --- redis/commands/helpers.py | 2 + redis/commands/json/__init__.py | 8 +- redis/commands/json/commands.py | 93 +++++++++++++--------- redis/commands/json/decoders.py | 12 +++ redis/commands/json/path.py | 11 +-- requirements.txt | 1 + setup.py | 3 + tests/test_json.py | 132 ++++++++++++++++++++++++-------- tox.ini | 4 +- 9 files changed, 186 insertions(+), 80 deletions(-) create mode 100644 redis/commands/json/decoders.py create mode 100644 requirements.txt diff --git a/redis/commands/helpers.py b/redis/commands/helpers.py index a92c02503e..48ee5568e6 100644 --- a/redis/commands/helpers.py +++ b/redis/commands/helpers.py @@ -22,6 +22,8 @@ def nativestr(x): def delist(x): """Given a list of binaries, return the stringified version.""" + if x is None: + return x return [nativestr(obj) for obj in x] diff --git a/redis/commands/json/__init__.py b/redis/commands/json/__init__.py index 978370553a..3149bb8f50 100644 --- a/redis/commands/json/__init__.py +++ b/redis/commands/json/__init__.py @@ -1,5 +1,9 @@ from json import JSONDecoder, JSONEncoder +from .decoders import ( + int_or_list, + int_or_none +) from .helpers import bulk_of_jsons from ..helpers import nativestr, delist from .commands import JSONCommands @@ -48,13 +52,13 @@ def __init__( "JSON.ARRAPPEND": int, "JSON.ARRINDEX": int, "JSON.ARRINSERT": int, - "JSON.ARRLEN": int, + "JSON.ARRLEN": int_or_none, "JSON.ARRPOP": self._decode, "JSON.ARRTRIM": int, "JSON.OBJLEN": int, "JSON.OBJKEYS": delist, # "JSON.RESP": delist, - "JSON.DEBUG": int, + "JSON.DEBUG": int_or_list, } self.client = client diff --git a/redis/commands/json/commands.py b/redis/commands/json/commands.py index 2f8039f8bf..fb00e220aa 100644 --- a/redis/commands/json/commands.py +++ b/redis/commands/json/commands.py @@ -1,5 +1,7 @@ -from .path import Path, str_path +from .path import Path from .helpers import decode_dict_keys +from deprecated import deprecated +from redis.exceptions import DataError class JSONCommands: @@ -9,7 +11,7 @@ def arrappend(self, name, path=Path.rootPath(), *args): """Append the objects ``args`` to the array under the ``path` in key ``name``. 
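        A minimal usage sketch (assumes a RedisJSON-enabled server and an
        existing client ``r``)::

            >>> r.json().set("arr", Path.rootPath(), [1])
            >>> r.json().arrappend("arr", Path.rootPath(), 2, 3)
            3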
""" - pieces = [name, str_path(path)] + pieces = [name, str(path)] for o in args: pieces.append(self._encode(o)) return self.execute_command("JSON.ARRAPPEND", *pieces) @@ -23,7 +25,7 @@ def arrindex(self, name, path, scalar, start=0, stop=-1): and exclusive ``stop`` indices. """ return self.execute_command( - "JSON.ARRINDEX", name, str_path(path), self._encode(scalar), + "JSON.ARRINDEX", name, str(path), self._encode(scalar), start, stop ) @@ -31,67 +33,64 @@ def arrinsert(self, name, path, index, *args): """Insert the objects ``args`` to the array at index ``index`` under the ``path` in key ``name``. """ - pieces = [name, str_path(path), index] + pieces = [name, str(path), index] for o in args: pieces.append(self._encode(o)) return self.execute_command("JSON.ARRINSERT", *pieces) - def forget(self, name, path=Path.rootPath()): - """Alias for jsondel (delete the JSON value).""" - return self.execute_command("JSON.FORGET", name, str_path(path)) - def arrlen(self, name, path=Path.rootPath()): """Return the length of the array JSON value under ``path`` at key``name``. """ - return self.execute_command("JSON.ARRLEN", name, str_path(path)) + return self.execute_command("JSON.ARRLEN", name, str(path)) def arrpop(self, name, path=Path.rootPath(), index=-1): """Pop the element at ``index`` in the array JSON value under ``path`` at key ``name``. """ - return self.execute_command("JSON.ARRPOP", name, str_path(path), index) + return self.execute_command("JSON.ARRPOP", name, str(path), index) def arrtrim(self, name, path, start, stop): """Trim the array JSON value under ``path`` at key ``name`` to the inclusive range given by ``start`` and ``stop``. """ - return self.execute_command("JSON.ARRTRIM", name, str_path(path), + return self.execute_command("JSON.ARRTRIM", name, str(path), start, stop) def type(self, name, path=Path.rootPath()): """Get the type of the JSON value under ``path`` from key ``name``.""" - return self.execute_command("JSON.TYPE", name, str_path(path)) + return self.execute_command("JSON.TYPE", name, str(path)) def resp(self, name, path=Path.rootPath()): """Return the JSON value under ``path`` at key ``name``.""" - return self.execute_command("JSON.RESP", name, str_path(path)) + return self.execute_command("JSON.RESP", name, str(path)) def objkeys(self, name, path=Path.rootPath()): """Return the key names in the dictionary JSON value under ``path`` at key ``name``.""" - return self.execute_command("JSON.OBJKEYS", name, str_path(path)) + return self.execute_command("JSON.OBJKEYS", name, str(path)) def objlen(self, name, path=Path.rootPath()): """Return the length of the dictionary JSON value under ``path`` at key ``name``. """ - return self.execute_command("JSON.OBJLEN", name, str_path(path)) + return self.execute_command("JSON.OBJLEN", name, str(path)) def numincrby(self, name, path, number): """Increment the numeric (integer or floating point) JSON value under ``path`` at key ``name`` by the provided ``number``. """ return self.execute_command( - "JSON.NUMINCRBY", name, str_path(path), self._encode(number) + "JSON.NUMINCRBY", name, str(path), self._encode(number) ) + @deprecated(version='4.0.0', reason='deprecated since redisjson 1.0.0') def nummultby(self, name, path, number): """Multiply the numeric (integer or floating point) JSON value under ``path`` at key ``name`` with the provided ``number``. 
""" return self.execute_command( - "JSON.NUMMULTBY", name, str_path(path), self._encode(number) + "JSON.NUMMULTBY", name, str(path), self._encode(number) ) def clear(self, name, path=Path.rootPath()): @@ -102,11 +101,14 @@ def clear(self, name, path=Path.rootPath()): Return the count of cleared paths (ignoring non-array and non-objects paths). """ - return self.execute_command("JSON.CLEAR", name, str_path(path)) + return self.execute_command("JSON.CLEAR", name, str(path)) + + def delete(self, key, path=Path.rootPath()): + """Delete the JSON value stored at key ``key`` under ``path``.""" + return self.execute_command("JSON.DEL", key, str(path)) - def delete(self, name, path=Path.rootPath()): - """Delete the JSON value stored at key ``name`` under ``path``.""" - return self.execute_command("JSON.DEL", name, str_path(path)) + # forget is an alias for delete + forget = delete def get(self, name, *args, no_escape=False): """ @@ -125,7 +127,7 @@ def get(self, name, *args, no_escape=False): else: for p in args: - pieces.append(str_path(p)) + pieces.append(str(p)) # Handle case where key doesn't exist. The JSONDecoder would raise a # TypeError exception since it can't decode None @@ -134,13 +136,14 @@ def get(self, name, *args, no_escape=False): except TypeError: return None - def mget(self, path, *args): - """Get the objects stored as a JSON values under ``path`` from keys - ``args``. + def mget(self, keys, path): + """ + Get the objects stored as a JSON values under ``path``. ``keys`` + is a list of one or more keys. """ pieces = [] - pieces.extend(args) - pieces.append(str_path(path)) + pieces += keys + pieces.append(str(path)) return self.execute_command("JSON.MGET", *pieces) def set(self, name, path, obj, nx=False, xx=False, decode_keys=False): @@ -155,7 +158,7 @@ def set(self, name, path, obj, nx=False, xx=False, decode_keys=False): if decode_keys: obj = decode_dict_keys(obj) - pieces = [name, str_path(path), self._encode(obj)] + pieces = [name, str(path), self._encode(obj)] # Handle existential modifiers if nx and xx: @@ -169,29 +172,43 @@ def set(self, name, path, obj, nx=False, xx=False, decode_keys=False): pieces.append("XX") return self.execute_command("JSON.SET", *pieces) - def strlen(self, name, path=Path.rootPath()): + def strlen(self, name, path=None): """Return the length of the string JSON value under ``path`` at key ``name``. """ - return self.execute_command("JSON.STRLEN", name, str_path(path)) + pieces = [name] + if path is not None: + pieces.append(str(path)) + return self.execute_command("JSON.STRLEN", *pieces) def toggle(self, name, path=Path.rootPath()): """Toggle boolean value under ``path`` at key ``name``. returning the new value. """ - return self.execute_command("JSON.TOGGLE", name, str_path(path)) + return self.execute_command("JSON.TOGGLE", name, str(path)) - def strappend(self, name, string, path=Path.rootPath()): - """Append to the string JSON value under ``path`` at key ``name`` - the provided ``string``. + def strappend(self, name, value, path=Path.rootPath()): + """Append to the string JSON value. If two options are specified after + the key name, the path is determined to be the first. If a single + option is passed, then the rootpath (i.e Path.rootPath()) is used. 
""" + pieces = [name, str(path), value] return self.execute_command( - "JSON.STRAPPEND", name, str_path(path), self._encode(string) + "JSON.STRAPPEND", *pieces ) - def debug(self, name, path=Path.rootPath()): + def debug(self, subcommand, key=None, path=Path.rootPath()): """Return the memory usage in bytes of a value under ``path`` from key ``name``. """ - return self.execute_command("JSON.DEBUG", "MEMORY", - name, str_path(path)) + valid_subcommands = ["MEMORY", "HELP"] + if subcommand not in valid_subcommands: + raise DataError("The only valid subcommands are ", + str(valid_subcommands)) + pieces = [subcommand] + if subcommand == "MEMORY": + if key is None: + raise DataError("No key specified") + pieces.append(key) + pieces.append(str(path)) + return self.execute_command("JSON.DEBUG", *pieces) diff --git a/redis/commands/json/decoders.py b/redis/commands/json/decoders.py new file mode 100644 index 0000000000..0ee102a433 --- /dev/null +++ b/redis/commands/json/decoders.py @@ -0,0 +1,12 @@ +def int_or_list(b): + if isinstance(b, int): + return b + else: + return b + + +def int_or_none(b): + if b is None: + return None + if isinstance(b, int): + return b diff --git a/redis/commands/json/path.py b/redis/commands/json/path.py index dff86482df..6d87045155 100644 --- a/redis/commands/json/path.py +++ b/redis/commands/json/path.py @@ -1,11 +1,3 @@ -def str_path(p): - """Return the string representation of a path if it is of class Path.""" - if isinstance(p, Path): - return p.strPath - else: - return p - - class Path(object): """This class represents a path in a JSON value.""" @@ -19,3 +11,6 @@ def rootPath(): def __init__(self, path): """Make a new path based on the string representation in `path`.""" self.strPath = path + + def __repr__(self): + return self.strPath diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..9f8d5502da --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +deprecated diff --git a/setup.py b/setup.py index d0c81b40dd..6c712bd1d4 100644 --- a/setup.py +++ b/setup.py @@ -23,6 +23,9 @@ author="Redis Inc.", author_email="oss@redis.com", python_requires=">=3.6", + install_requires=[ + 'deprecated' + ], classifiers=[ "Development Status :: 5 - Production/Stable", "Environment :: Console", diff --git a/tests/test_json.py b/tests/test_json.py index c0b4d9ee4c..d78378bb9f 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -67,45 +67,52 @@ def test_jsonsetexistentialmodifiersshouldsucceed(self, client): with pytest.raises(Exception): client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) + @pytest.mark.redismod def test_mgetshouldsucceed(self, client): client.json().set("1", Path.rootPath(), 1) client.json().set("2", Path.rootPath(), 2) - r = client.json().mget(Path.rootPath(), "1", "2") - e = [1, 2] - assert e == r + assert client.json().mget(["1"], Path.rootPath()) == [1] + + assert client.json().mget([1, 2], Path.rootPath()) == [1, 2] + @pytest.mark.redismod - @skip_ifmodversion_lt("99.99.99", - "ReJSON") # todo: update after the release - def test_clearShouldSucceed(self, client): + @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release + def test_clear(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 1 == client.json().clear("arr", Path.rootPath()) assert [] == client.json().get("arr") + @pytest.mark.redismod - def test_typeshouldsucceed(self, client): + def test_type(self, client): client.json().set("1", Path.rootPath(), 1) + assert b"integer" == client.json().type("1", 
Path.rootPath()) assert b"integer" == client.json().type("1") + @pytest.mark.redismod - def test_numincrbyshouldsucceed(self, client): + def test_numincrby(self, client): client.json().set("num", Path.rootPath(), 1) assert 2 == client.json().numincrby("num", Path.rootPath(), 1) assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) + @pytest.mark.redismod - def test_nummultbyshouldsucceed(self, client): + def test_nummultby(self, client): client.json().set("num", Path.rootPath(), 1) - assert 2 == client.json().nummultby("num", Path.rootPath(), 2) - assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) - assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) + + with pytest.deprecated_call(): + assert 2 == client.json().nummultby("num", Path.rootPath(), 2) + assert 5 == client.json().nummultby("num", Path.rootPath(), 2.5) + assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) + @pytest.mark.redismod - @skip_ifmodversion_lt("99.99.99", - "ReJSON") # todo: update after the release - def test_toggleShouldSucceed(self, client): + @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release + def test_toggle(self, client): client.json().set("bool", Path.rootPath(), False) assert client.json().toggle("bool", Path.rootPath()) assert not client.json().toggle("bool", Path.rootPath()) @@ -114,39 +121,54 @@ def test_toggleShouldSucceed(self, client): with pytest.raises(redis.exceptions.ResponseError): client.json().toggle("num", Path.rootPath()) + @pytest.mark.redismod - def test_strappendshouldsucceed(self, client): - client.json().set("str", Path.rootPath(), "foo") - assert 6 == client.json().strappend("str", "bar", Path.rootPath()) - assert "foobar" == client.json().get("str", Path.rootPath()) + def test_strappend(self, client): + client.json().set("jsonkey", Path.rootPath(), 'foo') + import json + assert 6 == client.json().strappend("jsonkey", json.dumps('bar')) + with pytest.raises(redis.exceptions.ResponseError): + assert 6 == client.json().strappend("jsonkey", 'bar') + assert "foobar" == client.json().get("jsonkey", Path.rootPath()) + @pytest.mark.redismod def test_debug(self, client): client.json().set("str", Path.rootPath(), "foo") - assert 24 == client.json().debug("str", Path.rootPath()) + assert 24 == client.json().debug("MEMORY", "str", Path.rootPath()) + assert 24 == client.json().debug("MEMORY", "str") + + # technically help is valid + assert isinstance(client.json().debug("HELP"), list) + @pytest.mark.redismod - def test_strlenshouldsucceed(self, client): + def test_strlen(self, client): client.json().set("str", Path.rootPath(), "foo") assert 3 == client.json().strlen("str", Path.rootPath()) - client.json().strappend("str", "bar", Path.rootPath()) + import json + client.json().strappend("str", json.dumps("bar"), Path.rootPath()) assert 6 == client.json().strlen("str", Path.rootPath()) + assert 6 == client.json().strlen("str") + @pytest.mark.redismod - def test_arrappendshouldsucceed(self, client): + def test_arrappend(self, client): client.json().set("arr", Path.rootPath(), [1]) assert 2 == client.json().arrappend("arr", Path.rootPath(), 2) assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) + @pytest.mark.redismod - def testArrIndexShouldSucceed(self, client): + def test_arrindex(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 1 == 
client.json().arrindex("arr", Path.rootPath(), 1) assert -1 == client.json().arrindex("arr", Path.rootPath(), 1, 2) + @pytest.mark.redismod - def test_arrinsertshouldsucceed(self, client): + def test_arrinsert(self, client): client.json().set("arr", Path.rootPath(), [0, 4]) assert 5 - -client.json().arrinsert( "arr", @@ -160,13 +182,22 @@ def test_arrinsertshouldsucceed(self, client): ) assert [0, 1, 2, 3, 4] == client.json().get("arr") + # test prepends + client.json().set("val2", Path.rootPath(), [5, 6, 7, 8, 9]) + client.json().arrinsert("val2", Path.rootPath(), 0, ['some', 'thing']) + assert client.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] + + @pytest.mark.redismod - def test_arrlenshouldsucceed(self, client): + def test_arrlen(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 5 == client.json().arrlen("arr", Path.rootPath()) + assert 5 == client.json().arrlen("arr") + assert client.json().arrlen('fakekey') is None + @pytest.mark.redismod - def test_arrpopshouldsucceed(self, client): + def test_arrpop(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 4 == client.json().arrpop("arr", Path.rootPath(), 4) assert 3 == client.json().arrpop("arr", Path.rootPath(), -1) @@ -174,22 +205,50 @@ def test_arrpopshouldsucceed(self, client): assert 0 == client.json().arrpop("arr", Path.rootPath(), 0) assert [1] == client.json().get("arr") + # test out of bounds + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 4 == client.json().arrpop("arr", Path.rootPath(), 99) + + # none test + client.json().set("arr", Path.rootPath(), []) + assert client.json().arrpop("arr") is None + + @pytest.mark.redismod - def test_arrtrimshouldsucceed(self, client): + def test_arrtrim(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 3 == client.json().arrtrim("arr", Path.rootPath(), 1, 3) assert [1, 2, 3] == client.json().get("arr") + # <0 test, should be 0 equivalent + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 0 == client.json().arrtrim("arr", Path.rootPath(), -1, 3) + + # testing stop > end + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 2 == client.json().arrtrim("arr", Path.rootPath(), 3, 99) + + # start > array size and stop + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 0 == client.json().arrtrim("arr", Path.rootPath(), 9, 1) + + # all larger + client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) + assert 0 == client.json().arrtrim("arr", Path.rootPath(), 9, 11) + + @pytest.mark.redismod - def test_respshouldsucceed(self, client): + def test_resp(self, client): obj = {"foo": "bar", "baz": 1, "qaz": True} client.json().set("obj", Path.rootPath(), obj) assert b"bar" == client.json().resp("obj", Path("foo")) assert 1 == client.json().resp("obj", Path("baz")) assert client.json().resp("obj", Path("qaz")) + assert isinstance(client.json().resp("obj"), list) + @pytest.mark.redismod - def test_objkeysshouldsucceed(self, client): + def test_objkeys(self, client): obj = {"foo": "bar", "baz": "qaz"} client.json().set("obj", Path.rootPath(), obj) keys = client.json().objkeys("obj", Path.rootPath()) @@ -198,12 +257,23 @@ def test_objkeysshouldsucceed(self, client): exp.sort() assert exp == keys + client.json().set("obj", Path.rootPath(), obj) + keys = client.json().objkeys("obj") + assert keys == list(obj.keys()) + + assert client.json().objkeys("fakekey") is None + + @pytest.mark.redismod - def test_objlenshouldsucceed(self, 
client): + def test_objlen(self, client): obj = {"foo": "bar", "baz": "qaz"} client.json().set("obj", Path.rootPath(), obj) assert len(obj) == client.json().objlen("obj", Path.rootPath()) + client.json().set("obj", Path.rootPath(), obj) + assert len(obj) == client.json().objlen("obj") + + # @pytest.mark.pipeline # @pytest.mark.redismod # def test_pipelineshouldsucceed(client): diff --git a/tox.ini b/tox.ini index a511cc6059..ad6be36723 100644 --- a/tox.ini +++ b/tox.ini @@ -90,7 +90,9 @@ volumes = [testenv] -deps = -r {toxinidir}/dev_requirements.txt +deps = + -r {toxinidir}/requirements.txt + -r {toxinidir}/dev_requirements.txt docker = master replica From e8688efbbf8a14e2a506c21ef16c0d72559cff45 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 3 Nov 2021 19:06:53 +0200 Subject: [PATCH 07/22] Added two new marks: onlycluster to mark tests to be run only with cluster mode redis, and onlynoncluster to mark tests to be run only with non-cluster redis. --- tasks.py | 9 +- tests/conftest.py | 23 +- tests/test_cluster.py | 7 +- tests/test_commands.py | 7 +- tests/test_connection.py | 8 +- tests/test_connection_pool.py | 13 +- tests/test_json.py | 29 +- tests/test_lock.py | 6 +- tests/test_monitor.py | 5 +- tests/test_pipeline.py | 31 +- tests/test_pubsub.py | 17 +- tests/test_scripting.py | 5 +- tests/test_search.py | 2046 ++++++++++++++++----------------- tests/test_sentinel.py | 3 - tests/test_timeseries.py | 4 +- tox.ini | 17 +- 16 files changed, 1086 insertions(+), 1144 deletions(-) diff --git a/tasks.py b/tasks.py index 2631c702d4..c32bde7876 100644 --- a/tasks.py +++ b/tasks.py @@ -20,6 +20,13 @@ def devenv(c): run(cmd) +@task +def cluster(c): + """Run all Redis Cluster tests.""" + print("Starting RedisCluster tests") + run("tox -e cluster") + + @task def linters(c): """Run code linters""" @@ -42,8 +49,6 @@ def tests(c): """ print("Starting Redis tests") run("tox -e plain -e hiredis") - print("Starting RedisCluster tests") - run("tox -e plain -e hiredis -- --redis-url=redis://localhost:16379/0") @task diff --git a/tests/conftest.py b/tests/conftest.py index df809bf81d..35d2f6cbde 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,9 +55,11 @@ def pytest_sessionstart(session): REDIS_INFO["cluster_enabled"] = cluster_enabled # module info - redismod_url = session.config.getoption("--redismod-url") - info = _get_info(redismod_url) - REDIS_INFO["modules"] = info["modules"] + markers = session.config.getoption("-m") + if 'redismod' in markers and 'not redismod' not in markers: + redismod_url = session.config.getoption("--redismod-url") + info = _get_info(redismod_url) + REDIS_INFO["modules"] = info["modules"] if cluster_enabled: cluster_nodes = session.config.getoption("--cluster-nodes") @@ -103,8 +105,8 @@ def skip_unless_arch_bits(arch_bits): def skip_ifmodversion_lt(min_version: str, module_name: str): - modules = REDIS_INFO["modules"] - if modules == []: + modules = REDIS_INFO.get("modules") + if modules is None or modules == []: return pytest.mark.skipif(True, reason="No redis modules found") for j in modules: @@ -117,17 +119,6 @@ def skip_ifmodversion_lt(min_version: str, module_name: str): raise AttributeError("No redis module named {}".format(module_name)) -def skip_if_cluster_mode(): - return pytest.mark.skipif(REDIS_INFO["cluster_enabled"], - reason="This test isn't supported with cluster " - "mode") - - -def skip_if_not_cluster_mode(): - return pytest.mark.skipif(not REDIS_INFO["cluster_enabled"], - reason="Cluster-mode is required for this test") - - def 
_get_client(cls, request, single_connection_client=True, flushdb=True, from_url=None, **kwargs): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index fa78c04710..5027759ea3 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -20,7 +20,6 @@ from redis.crc import key_slot from .conftest import ( - skip_if_not_cluster_mode, _get_client, skip_if_server_version_lt ) @@ -176,7 +175,7 @@ def ok_response(connection, *args, **options): assert prev_primary.server_type == REPLICA -@skip_if_not_cluster_mode() +@pytest.mark.onlycluster class TestRedisClusterObj: def test_host_port_startup_node(self): """ @@ -559,7 +558,7 @@ def on_connect(connection): assert mock.called is True -@skip_if_not_cluster_mode() +@pytest.mark.onlycluster class TestClusterRedisCommands: def test_case_insensitive_command_names(self, r): assert r.cluster_response_callbacks['cluster addslots'] == \ @@ -1081,7 +1080,7 @@ def test_client_kill(self, r, r2): assert clients[0].get('name') == 'redis-py-c1' -@skip_if_not_cluster_mode() +@pytest.mark.onlycluster class TestNodesManager: def test_load_balancer(self, r): n_manager = r.nodes_manager diff --git a/tests/test_commands.py b/tests/test_commands.py index 27e548f4f9..de29e2035c 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -11,7 +11,6 @@ from redis.commands import CommandsParser from .conftest import ( _get_client, - skip_if_cluster_mode, skip_if_server_version_gte, skip_if_server_version_lt, skip_unless_arch_bits, @@ -47,7 +46,7 @@ def get_stream_message(client, stream, message_id): # RESPONSE CALLBACKS -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestResponseCallbacks: "Tests for the response callback system" @@ -62,7 +61,7 @@ def test_case_insensitive_command_names(self, r): assert r.response_callbacks['del'] == r.response_callbacks['DEL'] -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestRedisCommands: def test_command_on_invalid_key_type(self, r): r.lpush('a', '1') @@ -3681,7 +3680,7 @@ def test_replicaof(self, r): assert r.replicaof("NO", "ONE") -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestBinarySave: def test_binary_get_set(self, r): diff --git a/tests/test_connection.py b/tests/test_connection.py index 2ca858d263..14c77a6eb2 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -4,11 +4,11 @@ from redis.exceptions import InvalidResponse, ModuleError from redis.utils import HIREDIS_AVAILABLE -from .conftest import skip_if_server_version_lt, skip_if_cluster_mode +from .conftest import skip_if_server_version_lt @pytest.mark.skipif(HIREDIS_AVAILABLE, reason='PythonParser only') -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster def test_invalid_response(r): raw = b'x' parser = r.connection._parser @@ -18,14 +18,14 @@ def test_invalid_response(r): assert str(cm.value) == 'Protocol Error: %r' % raw -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster @skip_if_server_version_lt('4.0.0') def test_loaded_modules(r, modclient): assert r.loaded_modules == [] assert 'rejson' in modclient.loaded_modules.keys() -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster @skip_if_server_version_lt('4.0.0') def test_loading_external_modules(r, modclient): def inner(): diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 4708057e98..7f7a6eea07 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -7,8 +7,7 @@ from threading import Thread from redis.connection import ssl_available, to_bool -from .conftest import 
skip_if_server_version_lt, skip_if_cluster_mode,\ - _get_client +from .conftest import skip_if_server_version_lt, _get_client from .test_pubsub import wait_for_message @@ -482,7 +481,7 @@ def test_on_connect_error(self): assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.8') def test_busy_loading_disconnects_socket(self, r): """ @@ -493,7 +492,7 @@ def test_busy_loading_disconnects_socket(self, r): r.execute_command('DEBUG', 'ERROR', 'LOADING fake message') assert not r.connection._sock - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.8') def test_busy_loading_from_pipeline_immediate_command(self, r): """ @@ -509,7 +508,7 @@ def test_busy_loading_from_pipeline_immediate_command(self, r): assert len(pool._available_connections) == 1 assert not pool._available_connections[0]._sock - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.8') def test_busy_loading_from_pipeline(self, r): """ @@ -567,7 +566,7 @@ def test_connect_invalid_password_supplied(self, r): r.execute_command('DEBUG', 'ERROR', 'ERR invalid password') -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestMultiConnectionClient: @pytest.fixture() def r(self, request): @@ -581,7 +580,7 @@ def test_multi_connection_command(self, r): assert r.get('a') == b'123' -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestHealthCheck: interval = 60 diff --git a/tests/test_json.py b/tests/test_json.py index d78378bb9f..f1d42489e3 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,7 +1,7 @@ import pytest import redis from redis.commands.json.path import Path -from .conftest import skip_ifmodversion_lt, skip_if_cluster_mode +from .conftest import skip_ifmodversion_lt @pytest.fixture @@ -10,7 +10,7 @@ def client(modclient): return modclient -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestJson: @pytest.mark.redismod def test_json_setbinarykey(self, client): @@ -67,7 +67,6 @@ def test_jsonsetexistentialmodifiersshouldsucceed(self, client): with pytest.raises(Exception): client.json().set("obj", Path("foo"), "baz", nx=True, xx=True) - @pytest.mark.redismod def test_mgetshouldsucceed(self, client): client.json().set("1", Path.rootPath(), 1) @@ -76,22 +75,20 @@ def test_mgetshouldsucceed(self, client): assert client.json().mget([1, 2], Path.rootPath()) == [1, 2] - @pytest.mark.redismod - @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release + @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the + # release def test_clear(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 1 == client.json().clear("arr", Path.rootPath()) assert [] == client.json().get("arr") - @pytest.mark.redismod def test_type(self, client): client.json().set("1", Path.rootPath(), 1) assert b"integer" == client.json().type("1", Path.rootPath()) assert b"integer" == client.json().type("1") - @pytest.mark.redismod def test_numincrby(self, client): client.json().set("num", Path.rootPath(), 1) @@ -99,7 +96,6 @@ def test_numincrby(self, client): assert 2.5 == client.json().numincrby("num", Path.rootPath(), 0.5) assert 1.25 == client.json().numincrby("num", Path.rootPath(), -1.25) - @pytest.mark.redismod def test_nummultby(self, client): client.json().set("num", Path.rootPath(), 1) @@ -109,9 +105,9 @@ def test_nummultby(self, client): assert 5 == 
client.json().nummultby("num", Path.rootPath(), 2.5) assert 2.5 == client.json().nummultby("num", Path.rootPath(), 0.5) - @pytest.mark.redismod - @skip_ifmodversion_lt("99.99.99", "ReJSON") # todo: update after the release + @skip_ifmodversion_lt("99.99.99", + "ReJSON") # todo: update after the release def test_toggle(self, client): client.json().set("bool", Path.rootPath(), False) assert client.json().toggle("bool", Path.rootPath()) @@ -121,7 +117,6 @@ def test_toggle(self, client): with pytest.raises(redis.exceptions.ResponseError): client.json().toggle("num", Path.rootPath()) - @pytest.mark.redismod def test_strappend(self, client): client.json().set("jsonkey", Path.rootPath(), 'foo') @@ -131,7 +126,6 @@ def test_strappend(self, client): assert 6 == client.json().strappend("jsonkey", 'bar') assert "foobar" == client.json().get("jsonkey", Path.rootPath()) - @pytest.mark.redismod def test_debug(self, client): client.json().set("str", Path.rootPath(), "foo") @@ -141,7 +135,6 @@ def test_debug(self, client): # technically help is valid assert isinstance(client.json().debug("HELP"), list) - @pytest.mark.redismod def test_strlen(self, client): client.json().set("str", Path.rootPath(), "foo") @@ -151,7 +144,6 @@ def test_strlen(self, client): assert 6 == client.json().strlen("str", Path.rootPath()) assert 6 == client.json().strlen("str") - @pytest.mark.redismod def test_arrappend(self, client): client.json().set("arr", Path.rootPath(), [1]) @@ -159,14 +151,12 @@ def test_arrappend(self, client): assert 4 == client.json().arrappend("arr", Path.rootPath(), 3, 4) assert 7 == client.json().arrappend("arr", Path.rootPath(), *[5, 6, 7]) - @pytest.mark.redismod def test_arrindex(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 1 == client.json().arrindex("arr", Path.rootPath(), 1) assert -1 == client.json().arrindex("arr", Path.rootPath(), 1, 2) - @pytest.mark.redismod def test_arrinsert(self, client): client.json().set("arr", Path.rootPath(), [0, 4]) @@ -187,7 +177,6 @@ def test_arrinsert(self, client): client.json().arrinsert("val2", Path.rootPath(), 0, ['some', 'thing']) assert client.json().get("val2") == [["some", "thing"], 5, 6, 7, 8, 9] - @pytest.mark.redismod def test_arrlen(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) @@ -195,7 +184,6 @@ def test_arrlen(self, client): assert 5 == client.json().arrlen("arr") assert client.json().arrlen('fakekey') is None - @pytest.mark.redismod def test_arrpop(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) @@ -213,7 +201,6 @@ def test_arrpop(self, client): client.json().set("arr", Path.rootPath(), []) assert client.json().arrpop("arr") is None - @pytest.mark.redismod def test_arrtrim(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) @@ -236,7 +223,6 @@ def test_arrtrim(self, client): client.json().set("arr", Path.rootPath(), [0, 1, 2, 3, 4]) assert 0 == client.json().arrtrim("arr", Path.rootPath(), 9, 11) - @pytest.mark.redismod def test_resp(self, client): obj = {"foo": "bar", "baz": 1, "qaz": True} @@ -246,7 +232,6 @@ def test_resp(self, client): assert client.json().resp("obj", Path("qaz")) assert isinstance(client.json().resp("obj"), list) - @pytest.mark.redismod def test_objkeys(self, client): obj = {"foo": "bar", "baz": "qaz"} @@ -263,7 +248,6 @@ def test_objkeys(self, client): assert client.json().objkeys("fakekey") is None - @pytest.mark.redismod def test_objlen(self, client): obj = {"foo": "bar", "baz": "qaz"} @@ -273,7 +257,6 @@ def 
test_objlen(self, client): client.json().set("obj", Path.rootPath(), obj) assert len(obj) == client.json().objlen("obj") - # @pytest.mark.pipeline # @pytest.mark.redismod # def test_pipelineshouldsucceed(client): diff --git a/tests/test_lock.py b/tests/test_lock.py index ab62dfc820..66148edcfc 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -4,10 +4,10 @@ from redis.exceptions import LockError, LockNotOwnedError from redis.client import Redis from redis.lock import Lock -from .conftest import _get_client, skip_if_cluster_mode +from .conftest import _get_client -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestLock: @pytest.fixture() def r_decoded(self, request): @@ -221,7 +221,7 @@ def test_reacquiring_lock_no_longer_owned_raises_error(self, r): lock.reacquire() -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestLockClassSelection: def test_lock_class_argument(self, r): class MyLock: diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 5d065c9206..cdce0562c0 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -1,7 +1,8 @@ -from .conftest import wait_for_command, skip_if_cluster_mode +import pytest +from .conftest import wait_for_command -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestMonitor: def test_wait_command_not_found(self, r): "Make sure the wait_for_command func works when command is not found" diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index 8fadf46bf1..a759bc944e 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -1,8 +1,7 @@ import pytest import redis -from .conftest import wait_for_command, skip_if_server_version_lt, \ - skip_if_cluster_mode +from .conftest import wait_for_command, skip_if_server_version_lt class TestPipeline: @@ -60,7 +59,7 @@ def test_pipeline_no_transaction(self, r): assert r['b'] == b'b1' assert r['c'] == b'c1' - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_pipeline_no_transaction_watch(self, r): r['a'] = 0 @@ -72,7 +71,7 @@ def test_pipeline_no_transaction_watch(self, r): pipe.set('a', int(a) + 1) assert pipe.execute() == [True] - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_pipeline_no_transaction_watch_failure(self, r): r['a'] = 0 @@ -132,7 +131,7 @@ def test_exec_error_raised(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_transaction_with_empty_error_command(self, r): """ Commands with custom EMPTY_ERROR functionality return their default @@ -147,7 +146,7 @@ def test_transaction_with_empty_error_command(self, r): assert result[1] == [] assert result[2] - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_pipeline_with_empty_error_command(self, r): """ Commands with custom EMPTY_ERROR functionality return their default @@ -176,7 +175,7 @@ def test_parse_error_raised(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_parse_error_raised_transaction(self, r): with r.pipeline() as pipe: pipe.multi() @@ -192,7 +191,7 @@ def test_parse_error_raised_transaction(self, r): assert pipe.set('z', 'zzz').execute() == [True] assert r['z'] == b'zzz' - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_watch_succeed(self, r): r['a'] = 1 r['b'] = 2 @@ -210,7 +209,7 @@ def test_watch_succeed(self, r): assert pipe.execute() == [True] assert not pipe.watching - @skip_if_cluster_mode() + 
@pytest.mark.onlynoncluster def test_watch_failure(self, r): r['a'] = 1 r['b'] = 2 @@ -225,7 +224,7 @@ def test_watch_failure(self, r): assert not pipe.watching - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_watch_failure_in_empty_transaction(self, r): r['a'] = 1 r['b'] = 2 @@ -239,7 +238,7 @@ def test_watch_failure_in_empty_transaction(self, r): assert not pipe.watching - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_unwatch(self, r): r['a'] = 1 r['b'] = 2 @@ -252,7 +251,7 @@ def test_unwatch(self, r): pipe.get('a') assert pipe.execute() == [b'1'] - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_watch_exec_no_unwatch(self, r): r['a'] = 1 r['b'] = 2 @@ -273,7 +272,7 @@ def test_watch_exec_no_unwatch(self, r): unwatch_command = wait_for_command(r, m, 'UNWATCH') assert unwatch_command is None, "should not send UNWATCH" - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_watch_reset_unwatch(self, r): r['a'] = 1 @@ -288,7 +287,7 @@ def test_watch_reset_unwatch(self, r): assert unwatch_command is not None assert unwatch_command['command'] == 'UNWATCH' - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_transaction_callable(self, r): r['a'] = 1 r['b'] = 2 @@ -313,7 +312,7 @@ def my_transaction(pipe): assert result == [True] assert r['c'] == b'4' - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_transaction_callable_returns_value_from_callable(self, r): def callback(pipe): # No need to do anything here since we only want the return value @@ -368,7 +367,7 @@ def test_pipeline_with_bitfield(self, r): assert pipe == pipe2 assert response == [True, [0, 0, 15, 15, 14], b'1'] - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.0.0') def test_pipeline_discard(self, r): diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index ebb96de58b..ed2e244fed 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -7,8 +7,7 @@ import redis from redis.exceptions import ConnectionError -from .conftest import _get_client, skip_if_cluster_mode, \ - skip_if_server_version_lt +from .conftest import _get_client, skip_if_server_version_lt def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): @@ -120,7 +119,7 @@ def test_resubscribe_to_channels_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_resubscribe_on_reconnection(**kwargs) - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_resubscribe_to_patterns_on_reconnection(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_resubscribe_on_reconnection(**kwargs) @@ -175,7 +174,7 @@ def test_subscribe_property_with_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_subscribed_property(**kwargs) - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_subscribe_property_with_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_subscribed_property(**kwargs) @@ -219,7 +218,7 @@ def test_sub_unsub_resub_channels(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'channel') self._test_sub_unsub_resub(**kwargs) - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_sub_unsub_resub_patterns(self, r): kwargs = make_subscribe_test_data(r.pubsub(), 'pattern') self._test_sub_unsub_resub(**kwargs) @@ -307,7 +306,7 @@ def test_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message('message', 'foo', 'test 
message') - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_pattern_message_handler(self, r): p = r.pubsub(ignore_subscribe_messages=True) p.psubscribe(**{'f*': self.message_handler}) @@ -327,7 +326,7 @@ def test_unicode_channel_message_handler(self, r): assert wait_for_message(p) is None assert self.message == make_message('message', channel, 'test message') - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster # see: https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html # #known-limitations-with-pubsub def test_unicode_pattern_message_handler(self, r): @@ -405,7 +404,7 @@ def test_channel_publish(self, r): self.channel, self.data) - @skip_if_cluster_mode() + @pytest.mark.onlynoncluster def test_pattern_publish(self, r): p = r.pubsub() p.psubscribe(self.pattern) @@ -534,7 +533,7 @@ def test_send_pubsub_ping_message(self, r): pattern=None) -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestPubSubConnectionKilled: @skip_if_server_version_lt('3.0.0') diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 46a684e36d..25abee823c 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -1,9 +1,6 @@ import pytest from redis import exceptions -from .conftest import ( - skip_if_cluster_mode, -) multiply_script = """ local value = redis.call('GET', KEYS[1]) @@ -22,7 +19,7 @@ """ -@skip_if_cluster_mode() +@pytest.mark.onlynoncluster class TestScripting: @pytest.fixture(autouse=True) def reset_scripts(self, r): diff --git a/tests/test_search.py b/tests/test_search.py index 926b5ff3af..1cb04c820c 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -82,8 +82,8 @@ def createIndex(client, num_docs=100, definition=None): try: client.create_index( (TextField("play", weight=5.0), - TextField("txt"), - NumericField("chapter")), + TextField("txt"), + NumericField("chapter")), definition=definition, ) except redis.ResponseError: @@ -129,1091 +129,1049 @@ def client(modclient): @pytest.mark.redismod -def test_client(client): - num_docs = 500 - createIndex(client.ft(), num_docs=num_docs) - waitForIndex(client, "idx") - # verify info - info = client.ft().info() - for k in [ - "index_name", - "index_options", - "attributes", - "num_docs", - "max_doc_id", - "num_terms", - "num_records", - "inverted_sz_mb", - "offset_vectors_sz_mb", - "doc_table_size_mb", - "key_table_size_mb", - "records_per_doc_avg", - "bytes_per_record_avg", - "offsets_per_term_avg", - "offset_bits_per_record_avg", - ]: - assert k in info - - assert client.ft().index_name == info["index_name"] - assert num_docs == int(info["num_docs"]) - - res = client.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" +class TestSearch: + def test_client(self, client): + num_docs = 500 + createIndex(client.ft(), num_docs=num_docs) + waitForIndex(client, "idx") + # verify info + info = client.ft().info() + for k in [ + "index_name", + "index_options", + "attributes", + "num_docs", + "max_doc_id", + "num_terms", + "num_records", + "inverted_sz_mb", + "offset_vectors_sz_mb", + "doc_table_size_mb", + "key_table_size_mb", + "records_per_doc_avg", + "bytes_per_record_avg", + "offsets_per_term_avg", + "offset_bits_per_record_avg", + ]: + assert k in info + + assert client.ft().index_name == info["index_name"] + assert num_docs == int(info["num_docs"]) + + res = client.ft().search("henry iv") + assert isinstance(res, Result) + assert 225 
== res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + assert doc.play == "Henry IV" + assert len(doc.txt) > 0 + + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content()).total + vtotal = client.ft().search( + Query("kings").no_content().verbatim()).total + assert total > vtotal + + # test in fields + txt_total = ( + client.ft().search( + Query("henry").no_content().limit_fields("txt")).total + ) + play_total = ( + client.ft().search( + Query("henry").no_content().limit_fields("play")).total + ) + both_total = ( + client.ft() + .search( + Query("henry").no_content().limit_fields("play", "txt")) + .total + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" assert len(doc.txt) > 0 - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search(Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search( - Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search( - Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search(Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" - assert len(doc.txt) > 0 - - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search( - Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search( - Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") - - -@pytest.mark.redismod 
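# --- Editor's note -------------------------------------------------------
# The module-level tests removed below are folded into the ``TestSearch``
# class introduced above, so their per-function ``@pytest.mark.redismod``
# decorators collapse into the single class-level mark: pytest applies a
# class-level mark to every test method the class defines. The same marker
# machinery carries this commit's new ``onlycluster``/``onlynoncluster``
# marks. The enforcement code is not part of this excerpt; a minimal sketch,
# assuming it lives in tests/conftest.py and reuses the
# ``REDIS_INFO["cluster_enabled"]`` flag set in ``pytest_sessionstart``,
# could use the standard ``pytest_collection_modifyitems`` hook:
#
#     def pytest_collection_modifyitems(config, items):
#         cluster = REDIS_INFO.get("cluster_enabled", False)
#         for item in items:
#             if item.get_closest_marker("onlycluster") and not cluster:
#                 item.add_marker(
#                     pytest.mark.skip(reason="cluster mode required"))
#             elif item.get_closest_marker("onlynoncluster") and cluster:
#                 item.add_marker(
#                     pytest.mark.skip(
#                         reason="not supported in cluster mode"))
# -------------------------------------------------------------------------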
-@skip_ifmodversion_lt("2.2.0", "search") -def test_payloads(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo bar").with_payloads() - res = client.ft().search(q) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "foo baz" == res.docs[0].payload - assert res.docs[1].payload is None - - -@pytest.mark.redismod -def test_scores(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo baz") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo ~bar").with_scores() - res = client.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) - - -@pytest.mark.redismod -def test_replace(client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - waitForIndex(client, "idx") - - res = client.ft().search("foo bar") - assert 2 == res.total - client.ft().add_document( - "doc1", - replace=True, - txt="this is a replaced doc" - ) - - res = client.ft().search("foo bar") - assert 1 == res.total - assert "doc2" == res.docs[0].id - - res = client.ft().search("replaced doc") - assert 1 == res.total - assert "doc1" == res.docs[0].id - - -@pytest.mark.redismod -def test_stopwords(client): - client.ft().create_index( - (TextField("txt"),), - stopwords=["foo", "bar", "baz"] - ) - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="hello world") - waitForIndex(client, "idx") - - q1 = Query("foo bar").no_content() - q2 = Query("foo bar hello world").no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total - - -@pytest.mark.redismod -def test_filters(client): - client.ft().create_index( - (TextField("txt"), - NumericField("num"), - GeoField("loc")) - ) - client.ft().add_document( - "doc1", - txt="foo bar", - num=3.141, - loc="-0.441,51.458" - ) - client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") - - waitForIndex(client, "idx") - # Test numerical filter - q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() - q2 = ( - Query("foo") - .add_filter( - NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) - .no_content() - ) - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id - - # Test geo filter - q1 = Query("foo").add_filter( - GeoFilter("loc", -0.44, 51.45, 10)).no_content() - q2 = Query("foo").add_filter( - GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id - - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res - - -@pytest.mark.redismod -def test_payloads_with_no_content(client): - client.ft().create_index((TextField("txt"),)) - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", payload="foo baz2", txt="foo bar") - - q = Query("foo 
bar").with_payloads().no_content() - res = client.ft().search(q) - assert 2 == len(res.docs) - - -@pytest.mark.redismod -def test_sort_by(client): - client.ft().create_index( - (TextField("txt"), - NumericField("num", sortable=True)) - ) - client.ft().add_document("doc1", txt="foo bar", num=1) - client.ft().add_document("doc2", txt="foo baz", num=2) - client.ft().add_document("doc3", txt="foo qux", num=3) - - # Test sort - q1 = Query("foo").sort_by("num", asc=True).no_content() - q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_drop_index(): - """ - Ensure the index gets dropped by data remains by default - """ - for x in range(20): - for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: - idx = "HaveIt" - index = getClient() - index.hset("index:haveit", mapping={"name": "haveit"}) - idef = IndexDefinition(prefix=["index:"]) - index.ft(idx).create_index((TextField("name"),), definition=idef) - waitForIndex(index, idx) - index.ft(idx).dropindex(delete_documents=keep_docs[0]) - i = index.hgetall("index:haveit") - assert i == keep_docs[1] - - -@pytest.mark.redismod -def test_example(client): - # Creating the index definition and schema - client.ft().create_index( - (TextField("title", weight=5.0), - TextField("body")) - ) - - # Indexing a document - client.ft().add_document( - "doc1", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - ) - - # Searching with complex parameters: - q = Query("search engine").verbatim().no_content().paging(0, 5) - - res = client.ft().search(q) - assert res is not None - - -@pytest.mark.redismod -def test_auto_complete(client): - n = 0 - with open(TITLES_CSV) as f: - cr = csv.reader(f) - - for row in cr: - n += 1 - term, score = row[0], float(row[1]) - assert n == client.ft().sugadd("ac", Suggestion(term, score=score)) - - assert n == client.ft().suglen("ac") - ret = client.ft().sugget("ac", "bad", with_scores=True) - assert 2 == len(ret) - assert "badger" == ret[0].string - assert isinstance(ret[0].score, float) - assert 1.0 != ret[0].score - assert "badalte rishtey" == ret[1].string - assert isinstance(ret[1].score, float) - assert 1.0 != ret[1].score - - ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - assert 10 == len(ret) - assert 1.0 == ret[0].score - strs = {x.string for x in ret} - - for sug in strs: - assert 1 == client.ft().sugdel("ac", sug) - # make sure a second delete returns 0 - for sug in strs: - assert 0 == client.ft().sugdel("ac", sug) - - # make sure they were actually deleted - ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - for sug in ret2: - assert sug.string not in strs - - # Test with payload - client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) - client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - - sugs = client.ft().sugget( - "ac", - "pay", - with_payloads=True, - with_scores=True - ) - assert 3 == len(sugs) - for sug in sugs: - assert sug.payload - assert sug.payload.startswith("pl") - - -@pytest.mark.redismod -def test_no_index(client): - client.ft().create_index( - ( - TextField("field"), 
- TextField("text", no_index=True, sortable=True), - NumericField("numeric", no_index=True, sortable=True), - GeoField("geo", no_index=True, sortable=True), - TagField("tag", no_index=True, sortable=True), + # test in-keys + ids = [x.id for x in client.ft().search(Query("henry")).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king")).total + assert 3 == client.ft().search( + Query("henry king").slop(0).in_order()).total + assert 52 == client.ft().search( + Query("king henry").slop(0).in_order()).total + assert 53 == client.ft().search(Query("henry king").slop(0)).total + assert 167 == client.ft().search(Query("henry king").slop(100)).total + + # test delete document + client.ft().add_document("doc-5ghs2", play="Death of a Salesman") + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.ft().add_document("doc-5ghs2", play="Death of a Salesman") + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + client.ft().delete_document("doc-5ghs2") + + @skip_ifmodversion_lt("2.2.0", "search") + def test_payloads(self, client): + client.ft().create_index((TextField("txt"),)) + + client.ft().add_document("doc1", payload="foo baz", txt="foo bar") + client.ft().add_document("doc2", txt="foo bar") + + q = Query("foo bar").with_payloads() + res = client.ft().search(q) + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + assert "foo baz" == res.docs[0].payload + assert res.docs[1].payload is None + + def test_scores(self, client): + client.ft().create_index((TextField("txt"),)) + + client.ft().add_document("doc1", txt="foo baz") + client.ft().add_document("doc2", txt="foo bar") + + q = Query("foo ~bar").with_scores() + res = client.ft().search(q) + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + # todo: enable once new RS version is tagged + # self.assertEqual(0.2, res.docs[1].score) + + def test_replace(self, client): + client.ft().create_index((TextField("txt"),)) + + client.ft().add_document("doc1", txt="foo bar") + client.ft().add_document("doc2", txt="foo bar") + waitForIndex(client, "idx") + + res = client.ft().search("foo bar") + assert 2 == res.total + client.ft().add_document( + "doc1", + replace=True, + txt="this is a replaced doc" ) - ) - - client.ft().add_document( - "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" - ) - client.ft().add_document( - "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" - ) - waitForIndex(client, "idx") - - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total - - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id + res = client.ft().search("foo bar") + assert 1 == res.total + assert "doc2" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - 
assert "doc1" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id - - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id - - # Ensure exception is raised for non-indexable, non-sortable fields - with pytest.raises(Exception): - TextField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - NumericField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - GeoField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - TagField("name", no_index=True, sortable=False) - - -@pytest.mark.redismod -def test_partial(client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", partial=True) - client.ft().add_document("doc2", f3="f3_val", replace=True) - waitForIndex(client, "idx") + res = client.ft().search("replaced doc") + assert 1 == res.total + assert "doc1" == res.docs[0].id - # Search for f3 value. All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - -@pytest.mark.redismod -def test_no_create(client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", no_create=True) - client.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True) - waitForIndex(client, "idx") - - # Search for f3 value. 
All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - with pytest.raises(redis.ResponseError): + def test_stopwords(self, client): + client.ft().create_index( + (TextField("txt"),), + stopwords=["foo", "bar", "baz"] + ) + client.ft().add_document("doc1", txt="foo bar") + client.ft().add_document("doc2", txt="hello world") + waitForIndex(client, "idx") + + q1 = Query("foo bar").no_content() + q2 = Query("foo bar hello world").no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) + assert 0 == res1.total + assert 1 == res2.total + + def test_filters(self, client): + client.ft().create_index( + (TextField("txt"), + NumericField("num"), + GeoField("loc")) + ) client.ft().add_document( - "doc3", - f2="f2_val", - f3="f3_val", - no_create=True + "doc1", + txt="foo bar", + num=3.141, + loc="-0.441,51.458" + ) + client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") + + waitForIndex(client, "idx") + # Test numerical filter + q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() + q2 = ( + Query("foo") + .add_filter( + NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) + .no_content() + ) + res1, res2 = client.ft().search(q1), client.ft().search(q2) + + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id + + # Test geo filter + q1 = Query("foo").add_filter( + GeoFilter("loc", -0.44, 51.45, 10)).no_content() + q2 = Query("foo").add_filter( + GeoFilter("loc", -0.44, 51.45, 100)).no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) + + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id + + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res + + def test_payloads_with_no_content(self, client): + client.ft().create_index((TextField("txt"),)) + client.ft().add_document("doc1", payload="foo baz", txt="foo bar") + client.ft().add_document("doc2", payload="foo baz2", txt="foo bar") + + q = Query("foo bar").with_payloads().no_content() + res = client.ft().search(q) + assert 2 == len(res.docs) + + def test_sort_by(self, client): + client.ft().create_index( + (TextField("txt"), + NumericField("num", sortable=True)) + ) + client.ft().add_document("doc1", txt="foo bar", num=1) + client.ft().add_document("doc2", txt="foo baz", num=2) + client.ft().add_document("doc3", txt="foo qux", num=3) + + # Test sort + q1 = Query("foo").sort_by("num", asc=True).no_content() + q2 = Query("foo").sort_by("num", asc=False).no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) + + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id + + @skip_ifmodversion_lt("2.0.0", "search") + def test_drop_index(self): + """ + Ensure the index gets dropped by data remains by default + """ + for x in range(20): + for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: + idx = "HaveIt" + index = getClient() + index.hset("index:haveit", mapping={"name": "haveit"}) + idef = IndexDefinition(prefix=["index:"]) + 
index.ft(idx).create_index((TextField("name"),), + definition=idef) + waitForIndex(index, idx) + index.ft(idx).dropindex(delete_documents=keep_docs[0]) + i = index.hgetall("index:haveit") + assert i == keep_docs[1] + + def test_example(self, client): + # Creating the index definition and schema + client.ft().create_index( + (TextField("title", weight=5.0), + TextField("body")) ) + # Indexing a document + client.ft().add_document( + "doc1", + title="RediSearch", + body="Redisearch impements a search engine on top of redis", + ) -@pytest.mark.redismod -def test_explain(client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) - res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") - assert res - - -@pytest.mark.redismod -def test_summarize(client): - createIndex(client.ft()) - waitForIndex(client, "idx") - - q = Query("king henry").paging(0, 1) - q.highlight(fields=("play", "txt"), tags=("", "")) - q.summarize("txt") - - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - - q = Query("king henry").paging(0, 1).summarize().highlight() - - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa - == doc.txt - ) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_alias(): - index1 = getClient() - index2 = getClient() - - index1.hset("index1:lonestar", mapping={"name": "lonestar"}) - index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - - if os.environ.get("GITHUB_WORKFLOW", None) is not None: - time.sleep(2) - else: - time.sleep(5) - - def1 = IndexDefinition(prefix=["index1:"], score_field="name") - def2 = IndexDefinition(prefix=["index2:"], score_field="name") - - ftindex1 = index1.ft("testAlias") - ftindex2 = index1.ft("testAlias2") - ftindex1.create_index((TextField("name"),), definition=def1) - ftindex2.create_index((TextField("name"),), definition=def2) - - # CI is slower - try: - res = ftindex1.search("*").docs[0] - except IndexError: - time.sleep(5) - res = ftindex1.search("*").docs[0] - assert "index1:lonestar" == res.id - - # create alias and check for results - ftindex1.aliasadd("spaceballs") - alias_client = getClient().ft("spaceballs") - res = alias_client.search("*").docs[0] - assert "index1:lonestar" == res.id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") - - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient().ft("spaceballs") - res = alias_client2.search("*").docs[0] - assert "index2:yogurt" == res.id - - ftindex2.aliasdel("spaceballs") - with pytest.raises(Exception): - alias_client2.search("*").docs[0] - - -@pytest.mark.redismod -def test_alias_basic(): - # Creating a client with one index - getClient().flushdb() - index1 = getClient().ft("testAlias") - - index1.create_index((TextField("txt"),)) - index1.add_document("doc1", txt="text goes here") - - index2 = getClient().ft("testAlias2") - index2.create_index((TextField("txt"),)) - index2.add_document("doc2", txt="text goes here") - - # add the actual alias and check - index1.aliasadd("myalias") - alias_client = getClient().ft("myalias") - res = 
sorted(alias_client.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = getClient().ft("myalias") - res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # delete the alias and expect an error if we try to query again - index2.aliasdel("myalias") - with pytest.raises(Exception): - _ = alias_client2.search("*").docs[0] - - -@pytest.mark.redismod -def test_tags(client): - client.ft().create_index((TextField("txt"), TagField("tags"))) - tags = "foo,foo bar,hello;world" - tags2 = "soba,ramen" - - client.ft().add_document("doc1", txt="fooz barz", tags=tags) - client.ft().add_document("doc2", txt="noodles", tags=tags2) - waitForIndex(client, "idx") - - q = Query("@tags:{foo}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo\\ bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{hello\\;world}") - res = client.ft().search(q) - assert 1 == res.total - - q2 = client.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() - - -@pytest.mark.redismod -def test_textfield_sortable_nostem(client): - # Creating the index definition with sortable and no_stem - client.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) - - # Now get the index info to confirm its contents - response = client.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] - - -@pytest.mark.redismod -def test_alter_schema_add(client): - # Creating the index definition and schema - client.ft().create_index(TextField("title")) - - # Using alter to add a field - client.ft().alter_schema_add(TextField("body")) - - # Indexing a document - client.ft().add_document( - "doc1", title="MyTitle", body="Some content only in the body" - ) - - # Searching with parameter only in the body (the added field) - q = Query("only in the body") - - # Ensure we find the result searching on the added body field - res = client.ft().search(q) - assert 1 == res.total - - -@pytest.mark.redismod -def test_spell_check(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - - client.ft().add_document( - "doc1", - f1="some valid content", - f2="this is sample text" - ) - client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") - waitForIndex(client, "idx") - - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") - - # test spellcheck exclude - res = 
client.ft().spellcheck("lorm", exclude="dict") - assert res == {} - - -@pytest.mark.redismod -def test_dict_operations(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - # Add three items - res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") - assert 3 == res - - # Remove one item - res = client.ft().dict_del("custom_dict", "item2") - assert 1 == res - - # Dump dict and inspect content - res = client.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res - - # Remove rest of the items before reload - client.ft().dict_del("custom_dict", *res) - - -@pytest.mark.redismod -def test_phonetic_matcher(client): - client.ft().create_index((TextField("name"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") - - res = client.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name - - # Drop and create index with phonetic matcher - client.flushdb() - - client.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") - - res = client.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted([d.name for d in res.docs]) - - -@pytest.mark.redismod -def test_scorer(client): - client.ft().create_index((TextField("description"),)) - - client.ft().add_document( - "doc1", description="The quick brown fox jumps over the lazy dog" - ) - client.ft().add_document( - "doc2", - description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.", # noqa - ) + # Searching with complex parameters: + q = Query("search engine").verbatim().no_content().paging(0, 5) + + res = client.ft().search(q) + assert res is not None + + def test_auto_complete(self, client): + n = 0 + with open(TITLES_CSV) as f: + cr = csv.reader(f) + + for row in cr: + n += 1 + term, score = row[0], float(row[1]) + assert n == client.ft().sugadd("ac", + Suggestion(term, score=score)) + + assert n == client.ft().suglen("ac") + ret = client.ft().sugget("ac", "bad", with_scores=True) + assert 2 == len(ret) + assert "badger" == ret[0].string + assert isinstance(ret[0].score, float) + assert 1.0 != ret[0].score + assert "badalte rishtey" == ret[1].string + assert isinstance(ret[1].score, float) + assert 1.0 != ret[1].score + + ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) + assert 10 == len(ret) + assert 1.0 == ret[0].score + strs = {x.string for x in ret} + + for sug in strs: + assert 1 == client.ft().sugdel("ac", sug) + # make sure a second delete returns 0 + for sug in strs: + assert 0 == client.ft().sugdel("ac", sug) + + # make sure they were actually deleted + ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) + for sug in ret2: + assert sug.string not in strs + + # Test with payload + client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) + client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) + client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) + + sugs = client.ft().sugget( + "ac", + "pay", + with_payloads=True, + with_scores=True + ) + assert 3 == len(sugs) + for sug in sugs: + assert sug.payload + assert sug.payload.startswith("pl") + + def test_no_index(self, client): + client.ft().create_index( + ( + TextField("field"), + TextField("text", no_index=True, sortable=True), + NumericField("numeric", no_index=True, sortable=True), + GeoField("geo", no_index=True, sortable=True), + 
TagField("tag", no_index=True, sortable=True), + ) + ) - # default scorer is TFIDF - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.1111111111111111 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score + client.ft().add_document( + "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" + ) + client.ft().add_document( + "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" + ) + waitForIndex(client, "idx") + res = client.ft().search(Query("@text:aa*")) + assert 0 == res.total -@pytest.mark.redismod -def test_get(client): - client.ft().create_index((TextField("f1"), TextField("f2"))) + res = client.ft().search(Query("@field:aa*")) + assert 2 == res.total - assert [None] == client.ft().get("doc1") - assert [None, None] == client.ft().get("doc2", "doc1") + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id - client.ft().add_document( - "doc1", f1="some valid content dd1", f2="this is sample text ff1" - ) - client.ft().add_document( - "doc2", f1="some valid content dd2", f2="this is sample text ff2" - ) + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id - assert [ - ["f1", "some valid content dd2", "f2", "this is sample text ff2"] - ] == client.ft().get("doc2") - assert [ - ["f1", "some valid content dd1", "f2", "this is sample text ff1"], - ["f1", "some valid content dd2", "f2", "this is sample text ff2"], - ] == client.ft().get("doc1", "doc2") + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_config(client): - assert client.ft().config_set("TIMEOUT", "100") - with pytest.raises(redis.ResponseError): - client.ft().config_set("TIMEOUT", "null") - res = client.ft().config_get("*") - assert "100" == res["TIMEOUT"] - res = client.ft().config_get("TIMEOUT") - assert "100" == res["TIMEOUT"] + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + # Ensure exception is raised for non-indexable, non-sortable fields + with pytest.raises(Exception): + TextField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + NumericField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + GeoField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + TagField("name", no_index=True, sortable=False) -@pytest.mark.redismod -def test_aggregations(client): - # Creating the index definition and schema - client.ft().create_index( - ( - NumericField("random_num"), - TextField("title"), - TextField("body"), - TextField("parent"), + def test_partial(self, client): + client.ft().create_index( + (TextField("f1"), + 
TextField("f2"), + TextField("f3")) + ) + client.ft().add_document("doc1", f1="f1_val", f2="f2_val") + client.ft().add_document("doc2", f1="f1_val", f2="f2_val") + client.ft().add_document("doc1", f3="f3_val", partial=True) + client.ft().add_document("doc2", f3="f3_val", replace=True) + waitForIndex(client, "idx") + + # Search for f3 value. All documents should have it + res = client.ft().search("@f3:f3_val") + assert 2 == res.total + + # Only the document updated with PARTIAL should still have f1 and f2 + # values + res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") + assert 1 == res.total + + def test_no_create(self, client): + client.ft().create_index( + (TextField("f1"), + TextField("f2"), + TextField("f3")) + ) + client.ft().add_document("doc1", f1="f1_val", f2="f2_val") + client.ft().add_document("doc2", f1="f1_val", f2="f2_val") + client.ft().add_document("doc1", f3="f3_val", no_create=True) + client.ft().add_document("doc2", f3="f3_val", no_create=True, + partial=True) + waitForIndex(client, "idx") + + # Search for f3 value. All documents should have it + res = client.ft().search("@f3:f3_val") + assert 2 == res.total + + # Only the document updated with PARTIAL should still have f1 and f2 + # values + res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") + assert 1 == res.total + + with pytest.raises(redis.ResponseError): + client.ft().add_document( + "doc3", + f2="f2_val", + f3="f3_val", + no_create=True + ) + + def test_explain(self, client): + client.ft().create_index( + (TextField("f1"), + TextField("f2"), + TextField("f3")) + ) + res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") + assert res + + def test_summarize(self, client): + createIndex(client.ft()) + waitForIndex(client, "idx") + + q = Query("king henry").paging(0, 1) + q.highlight(fields=("play", "txt"), tags=("", "")) + q.summarize("txt") + + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING " + "HENRY, LORD JOHN OF LANCASTER, the EARL of " + "WESTMORELAND, SIR... 
" # noqa + == doc.txt ) - ) - - # Indexing a document - client.ft().add_document( - "search", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - parent="redis", - random_num=10, - ) - client.ft().add_document( - "ai", - title="RedisAI", - body="RedisAI executes Deep Learning/Machine Learning models and managing their data.", # noqa - parent="redis", - random_num=3, - ) - client.ft().add_document( - "json", - title="RedisJson", - body="RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.", # noqa - parent="redis", - random_num=8, - ) - - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.count(), - reducers.count_distinct("@title"), - reducers.count_distinctish("@title"), - reducers.sum("@random_num"), - reducers.min("@random_num"), - reducers.max("@random_num"), - reducers.avg("@random_num"), - reducers.stddev("random_num"), - reducers.quantile("@random_num", 0.5), - reducers.tolist("@title"), - reducers.first_value("@title"), - reducers.random_sample("@title", 2), - ) - - res = client.ft().aggregate(req) - - res = res.rows[0] - assert len(res) == 26 - assert "redis" == res[1] - assert "3" == res[3] - assert "3" == res[5] - assert "3" == res[7] - assert "21" == res[9] - assert "3" == res[11] - assert "10" == res[13] - assert "7" == res[15] - assert "3.60555127546" == res[17] - assert "10" == res[19] - assert ["RediSearch", "RedisAI", "RedisJson"] == res[21] - assert "RediSearch" == res[23] - assert 2 == len(res[25]) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_index_definition(client): - """ - Create definition and test its args - """ - with pytest.raises(RuntimeError): - IndexDefinition(prefix=["hset:", "henry"], index_type="json") - - definition = IndexDefinition( - prefix=["hset:", "henry"], - filter="@f1==32", - language="English", - language_field="play", - score_field="chapter", - score=0.5, - payload_field="txt", - index_type=IndexType.JSON, - ) - - assert [ - "ON", - "JSON", - "PREFIX", - 2, - "hset:", - "henry", - "FILTER", - "@f1==32", - "LANGUAGE_FIELD", - "play", - "LANGUAGE", - "English", - "SCORE_FIELD", - "chapter", - "SCORE", - 0.5, - "PAYLOAD_FIELD", - "txt", - ] == definition.args - - createIndex(client.ft(), num_docs=500, definition=definition) - - -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_create_client_definition(client): - """ - Create definition with no index type provided, - and use hset to test the client definition (the default is HASH). - """ - definition = IndexDefinition(prefix=["hset:", "henry"]) - createIndex(client.ft(), num_docs=500, definition=definition) - info = client.ft().info() - assert 494 == int(info["num_docs"]) + q = Query("king henry").paging(0, 1).summarize().highlight() - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING " + "HENRY, LORD JOHN OF LANCASTER, the EARL of " + "WESTMORELAND, SIR... " # noqa + == doc.txt + ) + @skip_ifmodversion_lt("2.0.0", "search") + def test_alias(self): + index1 = getClient() + index2 = getClient() -@pytest.mark.redismod -@skip_ifmodversion_lt("2.0.0", "search") -def test_create_client_definition_hash(client): - """ - Create definition with IndexType.HASH as index type (ON HASH), - and use hset to test the client definition. 
- """ - definition = IndexDefinition( - prefix=["hset:", "henry"], - index_type=IndexType.HASH - ) - createIndex(client.ft(), num_docs=500, definition=definition) + index1.hset("index1:lonestar", mapping={"name": "lonestar"}) + index2.hset("index2:yogurt", mapping={"name": "yogurt"}) - info = client.ft().info() - assert 494 == int(info["num_docs"]) + if os.environ.get("GITHUB_WORKFLOW", None) is not None: + time.sleep(2) + else: + time.sleep(5) - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) + def1 = IndexDefinition(prefix=["index1:"], score_field="name") + def2 = IndexDefinition(prefix=["index2:"], score_field="name") + ftindex1 = index1.ft("testAlias") + ftindex2 = index1.ft("testAlias2") + ftindex1.create_index((TextField("name"),), definition=def1) + ftindex2.create_index((TextField("name"),), definition=def2) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_create_client_definition_json(client): - """ - Create definition with IndexType.JSON as index type (ON JSON), - and use json client to test it. - """ - definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) - client.ft().create_index((TextField("$.name"),), definition=definition) + # CI is slower + try: + res = ftindex1.search("*").docs[0] + except IndexError: + time.sleep(5) + res = ftindex1.search("*").docs[0] + assert "index1:lonestar" == res.id + + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = getClient().ft("spaceballs") + res = alias_client.search("*").docs[0] + assert "index1:lonestar" == res.id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient().ft("spaceballs") + res = alias_client2.search("*").docs[0] + assert "index2:yogurt" == res.id + + ftindex2.aliasdel("spaceballs") + with pytest.raises(Exception): + alias_client2.search("*").docs[0] + + def test_alias_basic(self): + # Creating a client with one index + getClient().flushdb() + index1 = getClient().ft("testAlias") + + index1.create_index((TextField("txt"),)) + index1.add_document("doc1", txt="text goes here") + + index2 = getClient().ft("testAlias2") + index2.create_index((TextField("txt"),)) + index2.add_document("doc2", txt="text goes here") + + # add the actual alias and check + index1.aliasadd("myalias") + alias_client = getClient().ft("myalias") + res = sorted(alias_client.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = getClient().ft("myalias") + res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # delete the alias and expect an error if we try to query again + index2.aliasdel("myalias") + with pytest.raises(Exception): + _ = alias_client2.search("*").docs[0] + + def test_tags(self, client): + client.ft().create_index((TextField("txt"), TagField("tags"))) + tags = "foo,foo bar,hello;world" + tags2 = "soba,ramen" + + client.ft().add_document("doc1", txt="fooz barz", tags=tags) + client.ft().add_document("doc2", txt="noodles", tags=tags2) + waitForIndex(client, "idx") + + q = Query("@tags:{foo}") + res = client.ft().search(q) + 
assert 1 == res.total + + q = Query("@tags:{foo bar}") + res = client.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo\\ bar}") + res = client.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{hello\\;world}") + res = client.ft().search(q) + assert 1 == res.total + + q2 = client.ft().tagvals("tags") + assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() + + def test_textfield_sortable_nostem(self, client): + # Creating the index definition with sortable and no_stem + client.ft().create_index( + (TextField("txt", sortable=True, no_stem=True),)) + + # Now get the index info to confirm its contents + response = client.ft().info() + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + + def test_alter_schema_add(self, client): + # Creating the index definition and schema + client.ft().create_index(TextField("title")) + + # Using alter to add a field + client.ft().alter_schema_add(TextField("body")) + + # Indexing a document + client.ft().add_document( + "doc1", title="MyTitle", body="Some content only in the body" + ) - client.json().set("king:1", Path.rootPath(), {"name": "henry"}) - client.json().set("king:2", Path.rootPath(), {"name": "james"}) + # Searching with parameter only in the body (the added field) + q = Query("only in the body") - res = client.ft().search("henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 + # Ensure we find the result searching on the added body field + res = client.ft().search(q) + assert 1 == res.total + def test_spell_check(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"))) -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_fields_as_name(client): - # create index - SCHEMA = ( - TextField("$.name", sortable=True, as_name="name"), - NumericField("$.age", as_name="just_a_number"), - ) - definition = IndexDefinition(index_type=IndexType.JSON) - client.ft().create_index(SCHEMA, definition=definition) - - # insert json data - res = client.json().set( - "doc:1", - Path.rootPath(), - {"name": "Jon", "age": 25} - ) - assert res + client.ft().add_document( + "doc1", + f1="some valid content", + f2="this is sample text" + ) + client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") + waitForIndex(client, "idx") + + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = client.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ( + "0.5", "0") + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} + + def test_dict_operations(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"))) + # Add three items + res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") + assert 3 == res 
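+
+        # Sketch of the counting semantics (assumes FT.DICTADD reports only
+        # newly added terms): re-adding an existing item adds nothing new.
+        assert 0 == client.ft().dict_add("custom_dict", "item1")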
+ + # Remove one item + res = client.ft().dict_del("custom_dict", "item2") + assert 1 == res + + # Dump dict and inspect content + res = client.ft().dict_dump("custom_dict") + assert ["item1", "item3"] == res + + # Remove rest of the items before reload + client.ft().dict_del("custom_dict", *res) + + def test_phonetic_matcher(self, client): + client.ft().create_index((TextField("name"),)) + client.ft().add_document("doc1", name="Jon") + client.ft().add_document("doc2", name="John") + + res = client.ft().search(Query("Jon")) + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + + # Drop and create index with phonetic matcher + client.flushdb() + + client.ft().create_index( + (TextField("name", phonetic_matcher="dm:en"),)) + client.ft().add_document("doc1", name="Jon") + client.ft().add_document("doc2", name="John") + + res = client.ft().search(Query("Jon")) + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted([d.name for d in res.docs]) + + def test_scorer(self, client): + client.ft().create_index((TextField("description"),)) - total = client.ft().search( - Query("Jon").return_fields("name", "just_a_number")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "Jon" == total[0].name - assert "25" == total[0].just_a_number + client.ft().add_document( + "doc1", description="The quick brown fox jumps over the lazy dog" + ) + client.ft().add_document( + "doc2", + description="Quick alice was beginning to get very tired of " + "sitting by her quick sister on the bank, and of " + "having nothing to do.", + # noqa + ) + # default scorer is TFIDF + res = client.ft().search(Query("quick").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search( + Query("quick").scorer("TFIDF.DOCNORM").with_scores()) + assert 0.1111111111111111 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("BM25").with_scores()) + assert 0.17699114465425977 == res.docs[0].score + res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) + assert 2.0 == res.docs[0].score + res = client.ft().search( + Query("quick").scorer("DOCSCORE").with_scores()) + assert 1.0 == res.docs[0].score + res = client.ft().search( + Query("quick").scorer("HAMMING").with_scores()) + assert 0.0 == res.docs[0].score + + def test_get(self, client): + client.ft().create_index((TextField("f1"), TextField("f2"))) + + assert [None] == client.ft().get("doc1") + assert [None, None] == client.ft().get("doc2", "doc1") -@pytest.mark.redismod -@skip_ifmodversion_lt("2.2.0", "search") -def test_search_return_fields(client): - res = client.json().set( - "doc:1", - Path.rootPath(), - {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, - ) - assert res + client.ft().add_document( + "doc1", f1="some valid content dd1", f2="this is sample text ff1" + ) + client.ft().add_document( + "doc2", f1="some valid content dd2", f2="this is sample text ff2" + ) - # create index on - definition = IndexDefinition(index_type=IndexType.JSON) - SCHEMA = ( - TextField("$.t"), - NumericField("$.flt"), - ) - client.ft().create_index(SCHEMA, definition=definition) - waitForIndex(client, "idx") + assert [ + ["f1", "some valid content dd2", "f2", + "this is sample text ff2"] + ] == client.ft().get("doc2") + assert [ + ["f1", "some valid content dd1", "f2", + "this is sample text ff1"], + ["f1", "some valid content dd2", "f2", + "this is sample text ff2"], + ] == 
client.ft().get("doc1", "doc2") + + @skip_ifmodversion_lt("2.2.0", "search") + def test_config(self, client): + assert client.ft().config_set("TIMEOUT", "100") + with pytest.raises(redis.ResponseError): + client.ft().config_set("TIMEOUT", "null") + res = client.ft().config_get("*") + assert "100" == res["TIMEOUT"] + res = client.ft().config_get("TIMEOUT") + assert "100" == res["TIMEOUT"] + + def test_aggregations(self, client): + # Creating the index definition and schema + client.ft().create_index( + ( + NumericField("random_num"), + TextField("title"), + TextField("body"), + TextField("parent"), + ) + ) - total = client.ft().search( - Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt + # Indexing a document + client.ft().add_document( + "search", + title="RediSearch", + body="Redisearch impements a search engine on top of redis", + parent="redis", + random_num=10, + ) + client.ft().add_document( + "ai", + title="RedisAI", + body="RedisAI executes Deep Learning/Machine Learning models and" + " managing their data.", + # noqa + parent="redis", + random_num=3, + ) + client.ft().add_document( + "json", + title="RedisJson", + body="RedisJSON implements ECMA-404 The JSON Data Interchange " + "Standard as a native data type.", + # noqa + parent="redis", + random_num=8, + ) - total = client.ft().search( - Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt + req = aggregations.AggregateRequest("redis").group_by( + "@parent", + reducers.count(), + reducers.count_distinct("@title"), + reducers.count_distinctish("@title"), + reducers.sum("@random_num"), + reducers.min("@random_num"), + reducers.max("@random_num"), + reducers.avg("@random_num"), + reducers.stddev("random_num"), + reducers.quantile("@random_num", 0.5), + reducers.tolist("@title"), + reducers.first_value("@title"), + reducers.random_sample("@title", 2), + ) + res = client.ft().aggregate(req) + + res = res.rows[0] + assert len(res) == 26 + assert "redis" == res[1] + assert "3" == res[3] + assert "3" == res[5] + assert "3" == res[7] + assert "21" == res[9] + assert "3" == res[11] + assert "10" == res[13] + assert "7" == res[15] + assert "3.60555127546" == res[17] + assert "10" == res[19] + assert ["RediSearch", "RedisAI", "RedisJson"] == res[21] + assert "RediSearch" == res[23] + assert 2 == len(res[25]) + + @skip_ifmodversion_lt("2.0.0", "search") + def test_index_definition(self, client): + """ + Create definition and test its args + """ + with pytest.raises(RuntimeError): + IndexDefinition(prefix=["hset:", "henry"], index_type="json") + + definition = IndexDefinition( + prefix=["hset:", "henry"], + filter="@f1==32", + language="English", + language_field="play", + score_field="chapter", + score=0.5, + payload_field="txt", + index_type=IndexType.JSON, + ) -@pytest.mark.redismod -def test_synupdate(client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - ( - TextField("title"), - TextField("body"), - ), - definition=definition, - ) + assert [ + "ON", + "JSON", + "PREFIX", + 2, + "hset:", + "henry", + "FILTER", + "@f1==32", + "LANGUAGE_FIELD", + "play", + "LANGUAGE", + "English", + "SCORE_FIELD", + "chapter", + "SCORE", + 0.5, + "PAYLOAD_FIELD", + "txt", + ] == definition.args + + createIndex(client.ft(), num_docs=500, definition=definition) + + @skip_ifmodversion_lt("2.0.0", "search") + def 
test_create_client_definition(self, client): + """ + Create definition with no index type provided, + and use hset to test the client definition (the default is HASH). + """ + definition = IndexDefinition(prefix=["hset:", "henry"]) + createIndex(client.ft(), num_docs=500, definition=definition) + + info = client.ft().info() + assert 494 == int(info["num_docs"]) + + client.ft().client.hset("hset:1", "f1", "v1") + info = client.ft().info() + assert 495 == int(info["num_docs"]) + + @skip_ifmodversion_lt("2.0.0", "search") + def test_create_client_definition_hash(self, client): + """ + Create definition with IndexType.HASH as index type (ON HASH), + and use hset to test the client definition. + """ + definition = IndexDefinition( + prefix=["hset:", "henry"], + index_type=IndexType.HASH + ) + createIndex(client.ft(), num_docs=500, definition=definition) + + info = client.ft().info() + assert 494 == int(info["num_docs"]) + + client.ft().client.hset("hset:1", "f1", "v1") + info = client.ft().info() + assert 495 == int(info["num_docs"]) + + @skip_ifmodversion_lt("2.2.0", "search") + def test_create_client_definition_json(self, client): + """ + Create definition with IndexType.JSON as index type (ON JSON), + and use json client to test it. + """ + definition = IndexDefinition(prefix=["king:"], + index_type=IndexType.JSON) + client.ft().create_index((TextField("$.name"),), definition=definition) + + client.json().set("king:1", Path.rootPath(), {"name": "henry"}) + client.json().set("king:2", Path.rootPath(), {"name": "james"}) + + res = client.ft().search("henry") + assert res.docs[0].id == "king:1" + assert res.docs[0].payload is None + assert res.docs[0].json == '{"name":"henry"}' + assert res.total == 1 + + @skip_ifmodversion_lt("2.2.0", "search") + def test_fields_as_name(self, client): + # create index + SCHEMA = ( + TextField("$.name", sortable=True, as_name="name"), + NumericField("$.age", as_name="just_a_number"), + ) + definition = IndexDefinition(index_type=IndexType.JSON) + client.ft().create_index(SCHEMA, definition=definition) + + # insert json data + res = client.json().set( + "doc:1", + Path.rootPath(), + {"name": "Jon", "age": 25} + ) + assert res + + total = client.ft().search( + Query("Jon").return_fields("name", "just_a_number")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "Jon" == total[0].name + assert "25" == total[0].just_a_number + + @skip_ifmodversion_lt("2.2.0", "search") + def test_search_return_fields(self, client): + res = client.json().set( + "doc:1", + Path.rootPath(), + {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, + ) + assert res - client.ft().synupdate("id1", True, "boy", "child", "offspring") - client.ft().add_document( - "doc1", - title="he is a baby", - body="this is a test") - - client.ft().synupdate("id1", True, "baby") - client.ft().add_document( - "doc2", - title="he is another baby", - body="another test" - ) + # create index on + definition = IndexDefinition(index_type=IndexType.JSON) + SCHEMA = ( + TextField("$.t"), + NumericField("$.flt"), + ) + client.ft().create_index(SCHEMA, definition=definition) + waitForIndex(client, "idx") + + total = client.ft().search( + Query("*").return_field("$.t", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "riceratops" == total[0].txt + + total = client.ft().search( + Query("*").return_field("$.t2", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "telmatosaurus" == total[0].txt + + def 
test_synupdate(self, client): + definition = IndexDefinition(index_type=IndexType.HASH) + client.ft().create_index( + ( + TextField("title"), + TextField("body"), + ), + definition=definition, + ) - res = client.ft().search(Query("child").expander("SYNONYM")) - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" + client.ft().synupdate("id1", True, "boy", "child", "offspring") + client.ft().add_document( + "doc1", + title="he is a baby", + body="this is a test") + client.ft().synupdate("id1", True, "baby") + client.ft().add_document( + "doc2", + title="he is another baby", + body="another test" + ) -@pytest.mark.redismod -def test_syndump(client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - ( - TextField("title"), - TextField("body"), - ), - definition=definition, - ) + res = client.ft().search(Query("child").expander("SYNONYM")) + assert res.docs[0].id == "doc2" + assert res.docs[0].title == "he is another baby" + assert res.docs[0].body == "another test" + + def test_syndump(self, client): + definition = IndexDefinition(index_type=IndexType.HASH) + client.ft().create_index( + ( + TextField("title"), + TextField("body"), + ), + definition=definition, + ) - client.ft().synupdate("id1", False, "boy", "child", "offspring") - client.ft().synupdate("id2", False, "baby", "child") - client.ft().synupdate("id3", False, "tree", "wood") - res = client.ft().syndump() - assert res == { - "boy": ["id1"], - "tree": ["id3"], - "wood": ["id3"], - "child": ["id1", "id2"], - "baby": ["id2"], - "offspring": ["id1"], - } + client.ft().synupdate("id1", False, "boy", "child", "offspring") + client.ft().synupdate("id2", False, "baby", "child") + client.ft().synupdate("id3", False, "tree", "wood") + res = client.ft().syndump() + assert res == { + "boy": ["id1"], + "tree": ["id3"], + "wood": ["id3"], + "child": ["id1", "id2"], + "baby": ["id2"], + "offspring": ["id1"], + } diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 7f66603085..1d55ea1215 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -5,7 +5,6 @@ from redis import exceptions from redis.sentinel import (Sentinel, SentinelConnectionPool, MasterNotFoundError, SlaveNotFoundError) -from .conftest import skip_if_cluster_mode import redis.sentinel @@ -14,7 +13,6 @@ def master_ip(master_host): yield socket.gethostbyname(master_host) -@skip_if_cluster_mode() class SentinelTestClient: def __init__(self, cluster, id): self.cluster = cluster @@ -55,7 +53,6 @@ def sentinel(request, cluster): return Sentinel([('foo', 26379), ('bar', 26379)]) -@skip_if_cluster_mode() class SentinelTestCluster: def __init__(self, servisentinel_ce_name='mymaster', ip='127.0.0.1', port=6379): diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index e066a7f385..941f2f9c5f 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -1,7 +1,7 @@ import pytest import time from time import sleep -from .conftest import skip_ifmodversion_lt, skip_if_cluster_mode +from .conftest import skip_ifmodversion_lt @pytest.fixture @@ -10,7 +10,7 @@ def client(modclient): return modclient -@skip_if_cluster_mode() +@pytest.mark.redismod class TestTimeseries: @pytest.mark.redismod def testCreate(self, client): diff --git a/tox.ini b/tox.ini index ad6be36723..1ad9c0a9aa 100644 --- a/tox.ini +++ b/tox.ini @@ -2,6 +2,8 @@ addopts = -s markers = redismod: run only the redis module tests + onlycluster: marks tests to be run only with 
cluster mode redis + onlynoncluster: marks tests to be run only with non-cluster redis [tox] minversion = 3.2.0 @@ -104,7 +106,20 @@ docker = extras = hiredis: hiredis commands = - pytest --cov=./ --cov-report=xml -W always {posargs} + pytest --cov=./ --cov-report=xml -W always -m 'not onlycluster and not redismod' {posargs} + pytest --cov=./ --cov-report=xml -W always -m 'not onlynoncluster and not redismod' --redis-url=redis://localhost:16379/0 {posargs} + + +[testenv:cluster] +deps = + {[testenv]deps} +docker = + redis_cluster +extras = + hiredis: hiredis +commands = + pytest --cov=./ --cov-report=xml -W always -m 'not onlynoncluster and not redismod' --redis-url=redis://localhost:16379/0 {posargs} + [testenv:devenv] skipsdist = true From f4cc29f059a66ff201730f964e02a6dfe6ce898b Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Thu, 4 Nov 2021 12:03:24 +0200 Subject: [PATCH 08/22] Resolving PR comments --- docker/cluster/redis.conf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docker/cluster/redis.conf b/docker/cluster/redis.conf index cc22e16ffe..3103806dee 100644 --- a/docker/cluster/redis.conf +++ b/docker/cluster/redis.conf @@ -1,3 +1,3 @@ # Redis Cluster config file will be shared across all nodes. -# Dont pass node-unique arguments (e.g. port, dir). +# Don't pass node-unique arguments (e.g. port, dir). cluster-enabled yes From d7fbec10ddb12a58e6e70ff9eb64b36668ad5f57 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Thu, 4 Nov 2021 17:59:12 +0200 Subject: [PATCH 09/22] Resolving PR comments --- docker/base/create_cluster.sh | 3 ++- docker/cluster/redis.conf | 4 ++-- redis/client.py | 6 ++--- redis/cluster.py | 1 + redis/commands/__init__.py | 12 +++++----- redis/commands/parser.py | 13 +++++------ redis/connection.py | 5 +++++ redis/crc.py | 8 ++----- redis/exceptions.py | 42 +++++++++++++++++++++++++++++++++-- tests/conftest.py | 37 +++++++++++++++++++++--------- tox.ini | 6 +++-- 11 files changed, 97 insertions(+), 40 deletions(-) diff --git a/docker/base/create_cluster.sh b/docker/base/create_cluster.sh index 0490842670..28aa3b1b8d 100644 --- a/docker/base/create_cluster.sh +++ b/docker/base/create_cluster.sh @@ -1,6 +1,6 @@ #! /bin/bash mkdir -p /nodes -echo -n > /nodes/nodemap +touch /nodes/nodemap for PORT in $(seq 16379 16384); do mkdir -p /nodes/$PORT if [[ -e /redis.conf ]]; then @@ -10,6 +10,7 @@ for PORT in $(seq 16379 16384); do fi cat << EOF >> /nodes/$PORT/redis.conf port $PORT +cluster-enabled yes daemonize yes logfile /redis.log dir /nodes/$PORT diff --git a/docker/cluster/redis.conf b/docker/cluster/redis.conf index 3103806dee..dff658c79b 100644 --- a/docker/cluster/redis.conf +++ b/docker/cluster/redis.conf @@ -1,3 +1,3 @@ # Redis Cluster config file will be shared across all nodes. -# Don't pass node-unique arguments (e.g. port, dir). 
-cluster-enabled yes
+# Do not change the following configurations that are already set:
+# port, cluster-enabled, daemonize, logfile, dir
diff --git a/redis/client.py b/redis/client.py
index 3768c2e5e1..438b8781ca 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -477,8 +477,8 @@ def _parse_node_line(line):
 
 def parse_cluster_nodes(response, **options):
     """
-    @see: http://redis.io/commands/cluster-nodes # string
-    @see: http://redis.io/commands/cluster-replicas # list of string
+    @see: https://redis.io/commands/cluster-nodes # string
+    @see: https://redis.io/commands/cluster-replicas # list of string
     """
     if isinstance(response, str):
         response = response.splitlines()
@@ -527,7 +527,7 @@ def parse_command(response, **options):
         cmd_dict = {}
         cmd_name = str_if_bytes(command[0])
         cmd_dict['name'] = cmd_name
-        cmd_dict['arity'] = str_if_bytes(command[1])
+        cmd_dict['arity'] = int(command[1])
         cmd_dict['flags'] = [str_if_bytes(flag) for flag in command[2]]
         cmd_dict['first_key_pos'] = command[3]
         cmd_dict['last_key_pos'] = command[4]
diff --git a/redis/cluster.py b/redis/cluster.py
index e3976dcb03..bb2dbba467 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -641,6 +641,7 @@ def _should_reinitialized(self):
     def keyslot(self, key):
         """
         Calculate keyslot for a given key.
+        See Keys distribution model in https://redis.io/topics/cluster-spec
         """
         k = self.encoder.encode(key)
         return key_slot(k)
diff --git a/redis/commands/__init__.py b/redis/commands/__init__.py
index 60f13d8d35..a4728d0ac4 100644
--- a/redis/commands/__init__.py
+++ b/redis/commands/__init__.py
@@ -1,15 +1,15 @@
+from .cluster import ClusterCommands
 from .core import CoreCommands
-from .redismodules import RedisModuleCommands
 from .helpers import list_or_args
-from .sentinel import SentinelCommands
-from .cluster import ClusterCommands
 from .parser import CommandsParser
+from .redismodules import RedisModuleCommands
+from .sentinel import SentinelCommands
 
 __all__ = [
-    'CoreCommands',
     'ClusterCommands',
     'CommandsParser',
+    'CoreCommands',
+    'list_or_args',
     'RedisModuleCommands',
-    'SentinelCommands',
-    'list_or_args'
+    'SentinelCommands'
 ]
diff --git a/redis/commands/parser.py b/redis/commands/parser.py
index 22478ed2ed..7a8004a913 100644
--- a/redis/commands/parser.py
+++ b/redis/commands/parser.py
@@ -66,14 +66,13 @@ def get_keys(self, redis_conn, *args):
         return keys
 
     def _get_moveable_keys(self, redis_conn, *args):
+        pieces = []
+        cmd_name = args[0]
+        # The command name should be split into separate arguments,
+        # e.g. 'MEMORY USAGE' will be split into ['MEMORY', 'USAGE']
+        pieces = pieces + cmd_name.split()
+        pieces = pieces + list(args[1:])
         try:
-            pieces = []
-            cmd_name = args[0]
-            for arg in cmd_name.split():
-                # The command name should be splitted into separate arguments,
-                # e.g. 'MEMORY USAGE' will be splitted into ['MEMORY', 'USAGE']
-                pieces.append(arg)
-            pieces += args[1:]
             keys = redis_conn.execute_command('COMMAND GETKEYS', *pieces)
         except ResponseError as e:
             message = e.__str__()
diff --git a/redis/connection.py b/redis/connection.py
index f2becbeba7..039b6a1437 100755
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -571,6 +571,11 @@ def clear_connect_callbacks(self):
         self._connect_callbacks = []
 
     def set_parser(self, parser_class):
+        """
+        Creates a new instance of parser_class, using this connection's
+        socket read size (_socket_read_size), and assigns it as the
+        connection's parser
+        :param parser_class: The required parser class
+        """
         self._parser = parser_class(socket_read_size=self._socket_read_size)
 
     def connect(self):
diff --git a/redis/crc.py b/redis/crc.py
index a4dfdf69f5..7d2ee507be 100644
--- a/redis/crc.py
+++ b/redis/crc.py
@@ -5,18 +5,14 @@
 REDIS_CLUSTER_HASH_SLOTS = 16384
 
 __all__ = [
-    "crc16",
     "key_slot",
     "REDIS_CLUSTER_HASH_SLOTS"
 ]
 
 
-def crc16(data):
-    return crc_hqx(data, 0)
-
-
 def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS):
     """Calculate key slot for a given key.
+    See Keys distribution model in https://redis.io/topics/cluster-spec
    :param key - bytes
    :param bucket - int
    """
@@ -25,4 +21,4 @@ def key_slot(key, bucket=REDIS_CLUSTER_HASH_SLOTS):
     end = key.find(b"}", start + 1)
     if end > -1 and end != start + 1:
         key = key[start + 1: end]
-    return crc16(key) % bucket
+    return crc_hqx(key, 0) % bucket
diff --git a/redis/exceptions.py b/redis/exceptions.py
index 5ea7fe9c30..ac88d03bd1 100644
--- a/redis/exceptions.py
+++ b/redis/exceptions.py
@@ -87,15 +87,30 @@ class AuthenticationWrongNumberOfArgsError(ResponseError):
 
 
 class RedisClusterException(Exception):
+    """
+    Base exception for the RedisCluster client
+    """
     pass
 
 
 class ClusterError(RedisError):
+    """
+    Raised when cluster errors occurred multiple times, resulting in an
+    exhaustion of the command execution TTL
+    """
     pass
 
 
 class ClusterDownError(ClusterError, ResponseError):
-
+    """
+    Raised when a CLUSTERDOWN error is received from the cluster.
+    By default, Redis Cluster nodes stop accepting queries if they detect
+    that at least one hash slot is uncovered (no available node is serving
+    it). This way, if the cluster is partially down (for example, a range
+    of hash slots is no longer covered), the whole cluster eventually
+    becomes unavailable. It automatically becomes available again as soon
+    as all the slots are covered.
+    """
     def __init__(self, resp):
         self.args = (resp,)
         self.message = resp
@@ -103,6 +118,11 @@ def __init__(self, resp):
 
 class AskError(ResponseError):
     """
+    Raised when an ASK error is received from the cluster.
+    When a slot is set as MIGRATING, the node will accept all queries that
+    are about this hash slot, but only if the key in question exists;
+    otherwise, the query is forwarded using a -ASK redirection to the node
+    that is the target of the migration.
 
     src node: MIGRATING to dst node
         get > ASK error
        ask dst node > ASKING command
@@ -122,20 +142,38 @@ def __init__(self, resp):
 
 
 class TryAgainError(ResponseError):
-
+    """
+    Raised when a TRYAGAIN error is received from the cluster.
+    Operations on keys that don't exist, or that are - during a resharding -
+    split between the source and destination nodes, will generate a
+    -TRYAGAIN error.
+    """
    def __init__(self, *args, **kwargs):
        pass


class ClusterCrossSlotError(ResponseError):
+    """
+    Raised when a CROSSSLOT error is received from the cluster.
+    A CROSSSLOT error is generated when keys in a request don't hash to the
+    same slot.
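+
+    Keys that share a hash tag are guaranteed to hash to the same slot. For
+    illustration, a small sketch using redis.crc.key_slot (note: key_slot
+    operates on raw bytes):
+
+        >>> from redis.crc import key_slot
+        >>> key_slot(b"{user}:name") == key_slot(b"{user}:age")
+        True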
+ """ message = "Keys in request don't hash to the same slot" class MovedError(AskError): + """ + Error indicated MOVED error received from cluster. + A request sent to a node that doesn't serve this key will be replayed with + a MOVED error that points to the correct node. + """ pass class MasterDownError(ClusterDownError): + """ + Error indicated MASTERDOWN error received from cluster. + Link with MASTER is down and replica-serve-stale-data is set to 'no'. + """ pass diff --git a/tests/conftest.py b/tests/conftest.py index 35d2f6cbde..c06c3a856d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -30,7 +30,7 @@ def pytest_addoption(parser): " with loaded modules," " defaults to `%(default)s`") - parser.addoption('--cluster-nodes', default=default_cluster_nodes, + parser.addoption('--redis-cluster-nodes', default=default_cluster_nodes, action="store", help="The number of cluster nodes that need to be " "available before the test can start," @@ -62,16 +62,25 @@ def pytest_sessionstart(session): REDIS_INFO["modules"] = info["modules"] if cluster_enabled: - cluster_nodes = session.config.getoption("--cluster-nodes") + cluster_nodes = session.config.getoption("--redis-cluster-nodes") wait_for_cluster_creation(redis_url, cluster_nodes) def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): + """ + Waits for the cluster creation to complete. + As soon as all :cluster_nodes: nodes become available, the cluster will be + considered ready. + :param redis_url: the cluster's url, e.g. redis://localhost:16379/0 + :param cluster_nodes: The number of nodes in the cluster + :param timeout: the amount of time to wait (in seconds) + """ now = time.time() - timeout = now + timeout + end_time = now + timeout + client = None print("Waiting for {0} cluster nodes to become available". format(cluster_nodes)) - while now < timeout: + while now < end_time: try: client = redis.RedisCluster.from_url(redis_url) if len(client.get_nodes()) == cluster_nodes: @@ -81,6 +90,12 @@ def wait_for_cluster_creation(redis_url, cluster_nodes, timeout=20): pass time.sleep(1) now = time.time() + if now >= end_time: + available_nodes = 0 if client is None else len(client.get_nodes()) + raise RedisClusterException( + "The cluster did not become available after {0} seconds. 
" + "Only {1} nodes out of {2} are available".format( + timeout, available_nodes, cluster_nodes)) def skip_if_server_version_lt(min_version): @@ -133,14 +148,14 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, redis_url = request.config.getoption("--redis-url") else: redis_url = from_url - if REDIS_INFO["cluster_enabled"]: - client = redis.RedisCluster.from_url(redis_url, **kwargs) - single_connection_client = False - else: + if not REDIS_INFO["cluster_enabled"]: url_options = parse_url(redis_url) url_options.update(kwargs) pool = redis.ConnectionPool(**url_options) client = cls(connection_pool=pool) + else: + client = redis.RedisCluster.from_url(redis_url, **kwargs) + single_connection_client = False if single_connection_client: client = client.client() if request: @@ -153,10 +168,10 @@ def teardown(): # just manually retry the flushdb client.flushdb() client.close() - if REDIS_INFO["cluster_enabled"]: - client.disconnect_connection_pools() - else: + if not REDIS_INFO["cluster_enabled"]: client.connection_pool.disconnect() + else: + client.disconnect_connection_pools() request.addfinalizer(teardown) return client diff --git a/tox.ini b/tox.ini index 51bab44f23..4645f5147d 100644 --- a/tox.ini +++ b/tox.ini @@ -79,7 +79,6 @@ volumes = [docker:redis_cluster] name = redis_cluster image = barshaul/redis-py:6.2.6-cluster -healtcheck_cmd = python -c "import socket;print(True) if 0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',16379)) else False" ports = 16379:16379/tcp 16380:16380/tcp @@ -87,6 +86,7 @@ ports = 16382:16382/tcp 16383:16383/tcp 16384:16384/tcp +healtcheck_cmd = python -c "import socket;print(True) if all([0 == socket.socket(socket.AF_INET, socket.SOCK_STREAM).connect_ex(('127.0.0.1',port)) for port in range(16379,16384)]) else False" volumes = bind:rw:{toxinidir}/docker/cluster/redis.conf:/redis.conf @@ -105,9 +105,11 @@ docker = redismod extras = hiredis: hiredis +setenv = + CLUSTER_URL = "redis://localhost:16379/0" commands = redis: pytest --cov=./ --cov-report=xml -W always -m 'not onlycluster and not redismod' {posargs} - cluster: pytest --cov=./ --cov-report=xml -W always -m 'not onlynoncluster and not redismod' --redis-url=redis://localhost:16379/0 {posargs} + cluster: pytest --cov=./ --cov-report=xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} [testenv:devenv] skipsdist = true From cd056e370addc9cd98f33f2dacefb241b7ebf748 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 7 Nov 2021 18:16:23 +0200 Subject: [PATCH 10/22] Merging from redis:master --- tests/test_json.py | 3 + tests/test_scripting.py | 1 + tests/test_search.py | 2046 +++++++++++++++++++------------------- tests/test_sentinel.py | 1 + tests/test_timeseries.py | 1079 ++++++++++---------- 5 files changed, 1595 insertions(+), 1535 deletions(-) diff --git a/tests/test_json.py b/tests/test_json.py index 19b0c3262e..00fa571561 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -991,6 +991,7 @@ def test_debug_dollar(client): assert client.json().debug("MEMORY", "non_existing_doc", "$..a") == [] +@pytest.mark.redismod def test_resp_dollar(client): data = { @@ -1137,6 +1138,7 @@ def test_resp_dollar(client): assert client.json().resp("non_existing_doc", "$..a") is None +@pytest.mark.redismod def test_arrindex_dollar(client): client.json().set( @@ -1377,6 +1379,7 @@ def test_arrindex_dollar(client): "None") == 0 +@pytest.mark.redismod def test_decoders_and_unstring(): assert unstring("4") == 4 assert 
unstring("45.55") == 45.55 diff --git a/tests/test_scripting.py b/tests/test_scripting.py index 25abee823c..e7c559d833 100644 --- a/tests/test_scripting.py +++ b/tests/test_scripting.py @@ -2,6 +2,7 @@ from redis import exceptions + multiply_script = """ local value = redis.call('GET', KEYS[1]) value = tonumber(value) diff --git a/tests/test_search.py b/tests/test_search.py index 1cb04c820c..926b5ff3af 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -82,8 +82,8 @@ def createIndex(client, num_docs=100, definition=None): try: client.create_index( (TextField("play", weight=5.0), - TextField("txt"), - NumericField("chapter")), + TextField("txt"), + NumericField("chapter")), definition=definition, ) except redis.ResponseError: @@ -129,1049 +129,1091 @@ def client(modclient): @pytest.mark.redismod -class TestSearch: - def test_client(self, client): - num_docs = 500 - createIndex(client.ft(), num_docs=num_docs) - waitForIndex(client, "idx") - # verify info - info = client.ft().info() - for k in [ - "index_name", - "index_options", - "attributes", - "num_docs", - "max_doc_id", - "num_terms", - "num_records", - "inverted_sz_mb", - "offset_vectors_sz_mb", - "doc_table_size_mb", - "key_table_size_mb", - "records_per_doc_avg", - "bytes_per_record_avg", - "offsets_per_term_avg", - "offset_bits_per_record_avg", - ]: - assert k in info - - assert client.ft().index_name == info["index_name"] - assert num_docs == int(info["num_docs"]) - - res = client.ft().search("henry iv") - assert isinstance(res, Result) - assert 225 == res.total - assert 10 == len(res.docs) - assert res.duration > 0 - - for doc in res.docs: - assert doc.id - assert doc.play == "Henry IV" - assert len(doc.txt) > 0 - - # test no content - res = client.ft().search(Query("king").no_content()) - assert 194 == res.total - assert 10 == len(res.docs) - for doc in res.docs: - assert "txt" not in doc.__dict__ - assert "play" not in doc.__dict__ - - # test verbatim vs no verbatim - total = client.ft().search(Query("kings").no_content()).total - vtotal = client.ft().search( - Query("kings").no_content().verbatim()).total - assert total > vtotal - - # test in fields - txt_total = ( - client.ft().search( - Query("henry").no_content().limit_fields("txt")).total - ) - play_total = ( - client.ft().search( - Query("henry").no_content().limit_fields("play")).total - ) - both_total = ( - client.ft() - .search( - Query("henry").no_content().limit_fields("play", "txt")) - .total - ) - assert 129 == txt_total - assert 494 == play_total - assert 494 == both_total - - # test load_document - doc = client.ft().load_document("henry vi part 3:62") - assert doc is not None - assert "henry vi part 3:62" == doc.id - assert doc.play == "Henry VI Part 3" +def test_client(client): + num_docs = 500 + createIndex(client.ft(), num_docs=num_docs) + waitForIndex(client, "idx") + # verify info + info = client.ft().info() + for k in [ + "index_name", + "index_options", + "attributes", + "num_docs", + "max_doc_id", + "num_terms", + "num_records", + "inverted_sz_mb", + "offset_vectors_sz_mb", + "doc_table_size_mb", + "key_table_size_mb", + "records_per_doc_avg", + "bytes_per_record_avg", + "offsets_per_term_avg", + "offset_bits_per_record_avg", + ]: + assert k in info + + assert client.ft().index_name == info["index_name"] + assert num_docs == int(info["num_docs"]) + + res = client.ft().search("henry iv") + assert isinstance(res, Result) + assert 225 == res.total + assert 10 == len(res.docs) + assert res.duration > 0 + + for doc in res.docs: + assert doc.id + 
assert doc.play == "Henry IV" assert len(doc.txt) > 0 - # test in-keys - ids = [x.id for x in client.ft().search(Query("henry")).docs] - assert 10 == len(ids) - subset = ids[:5] - docs = client.ft().search(Query("henry").limit_ids(*subset)) - assert len(subset) == docs.total - ids = [x.id for x in docs.docs] - assert set(ids) == set(subset) - - # test slop and in order - assert 193 == client.ft().search(Query("henry king")).total - assert 3 == client.ft().search( - Query("henry king").slop(0).in_order()).total - assert 52 == client.ft().search( - Query("king henry").slop(0).in_order()).total - assert 53 == client.ft().search(Query("henry king").slop(0)).total - assert 167 == client.ft().search(Query("henry king").slop(100)).total - - # test delete document - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - - assert 1 == client.ft().delete_document("doc-5ghs2") - res = client.ft().search(Query("death of a salesman")) - assert 0 == res.total - assert 0 == client.ft().delete_document("doc-5ghs2") - - client.ft().add_document("doc-5ghs2", play="Death of a Salesman") - res = client.ft().search(Query("death of a salesman")) - assert 1 == res.total - client.ft().delete_document("doc-5ghs2") - - @skip_ifmodversion_lt("2.2.0", "search") - def test_payloads(self, client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", payload="foo baz", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo bar").with_payloads() - res = client.ft().search(q) - assert 2 == res.total - assert "doc1" == res.docs[0].id - assert "doc2" == res.docs[1].id - assert "foo baz" == res.docs[0].payload - assert res.docs[1].payload is None - - def test_scores(self, client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo baz") - client.ft().add_document("doc2", txt="foo bar") - - q = Query("foo ~bar").with_scores() - res = client.ft().search(q) - assert 2 == res.total - assert "doc2" == res.docs[0].id - assert 3.0 == res.docs[0].score - assert "doc1" == res.docs[1].id - # todo: enable once new RS version is tagged - # self.assertEqual(0.2, res.docs[1].score) - - def test_replace(self, client): - client.ft().create_index((TextField("txt"),)) - - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="foo bar") - waitForIndex(client, "idx") - - res = client.ft().search("foo bar") - assert 2 == res.total - client.ft().add_document( - "doc1", - replace=True, - txt="this is a replaced doc" - ) + # test no content + res = client.ft().search(Query("king").no_content()) + assert 194 == res.total + assert 10 == len(res.docs) + for doc in res.docs: + assert "txt" not in doc.__dict__ + assert "play" not in doc.__dict__ + + # test verbatim vs no verbatim + total = client.ft().search(Query("kings").no_content()).total + vtotal = client.ft().search(Query("kings").no_content().verbatim()).total + assert total > vtotal + + # test in fields + txt_total = ( + client.ft().search( + Query("henry").no_content().limit_fields("txt")).total + ) + play_total = ( + client.ft().search( + Query("henry").no_content().limit_fields("play")).total + ) + both_total = ( + client.ft() + .search(Query("henry").no_content().limit_fields("play", "txt")) + .total + ) + assert 129 == txt_total + assert 494 == play_total + assert 494 == both_total + + # test load_document + doc = client.ft().load_document("henry vi part 3:62") + 
assert doc is not None + assert "henry vi part 3:62" == doc.id + assert doc.play == "Henry VI Part 3" + assert len(doc.txt) > 0 + + # test in-keys + ids = [x.id for x in client.ft().search(Query("henry")).docs] + assert 10 == len(ids) + subset = ids[:5] + docs = client.ft().search(Query("henry").limit_ids(*subset)) + assert len(subset) == docs.total + ids = [x.id for x in docs.docs] + assert set(ids) == set(subset) + + # test slop and in order + assert 193 == client.ft().search(Query("henry king")).total + assert 3 == client.ft().search( + Query("henry king").slop(0).in_order()).total + assert 52 == client.ft().search( + Query("king henry").slop(0).in_order()).total + assert 53 == client.ft().search(Query("henry king").slop(0)).total + assert 167 == client.ft().search(Query("henry king").slop(100)).total + + # test delete document + client.ft().add_document("doc-5ghs2", play="Death of a Salesman") + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + + assert 1 == client.ft().delete_document("doc-5ghs2") + res = client.ft().search(Query("death of a salesman")) + assert 0 == res.total + assert 0 == client.ft().delete_document("doc-5ghs2") + + client.ft().add_document("doc-5ghs2", play="Death of a Salesman") + res = client.ft().search(Query("death of a salesman")) + assert 1 == res.total + client.ft().delete_document("doc-5ghs2") - res = client.ft().search("foo bar") - assert 1 == res.total - assert "doc2" == res.docs[0].id - res = client.ft().search("replaced doc") - assert 1 == res.total - assert "doc1" == res.docs[0].id +@pytest.mark.redismod +@skip_ifmodversion_lt("2.2.0", "search") +def test_payloads(client): + client.ft().create_index((TextField("txt"),)) - def test_stopwords(self, client): - client.ft().create_index( - (TextField("txt"),), - stopwords=["foo", "bar", "baz"] - ) - client.ft().add_document("doc1", txt="foo bar") - client.ft().add_document("doc2", txt="hello world") - waitForIndex(client, "idx") - - q1 = Query("foo bar").no_content() - q2 = Query("foo bar hello world").no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - assert 0 == res1.total - assert 1 == res2.total - - def test_filters(self, client): - client.ft().create_index( - (TextField("txt"), - NumericField("num"), - GeoField("loc")) - ) - client.ft().add_document( - "doc1", - txt="foo bar", - num=3.141, - loc="-0.441,51.458" - ) - client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") - - waitForIndex(client, "idx") - # Test numerical filter - q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() - q2 = ( - Query("foo") - .add_filter( - NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) - .no_content() - ) - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 1 == res2.total - assert "doc2" == res1.docs[0].id - assert "doc1" == res2.docs[0].id - - # Test geo filter - q1 = Query("foo").add_filter( - GeoFilter("loc", -0.44, 51.45, 10)).no_content() - q2 = Query("foo").add_filter( - GeoFilter("loc", -0.44, 51.45, 100)).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 1 == res1.total - assert 2 == res2.total - assert "doc1" == res1.docs[0].id - - # Sort results, after RDB reload order may change - res = [res2.docs[0].id, res2.docs[1].id] - res.sort() - assert ["doc1", "doc2"] == res - - def test_payloads_with_no_content(self, client): - client.ft().create_index((TextField("txt"),)) - client.ft().add_document("doc1", payload="foo baz", 
txt="foo bar") - client.ft().add_document("doc2", payload="foo baz2", txt="foo bar") - - q = Query("foo bar").with_payloads().no_content() - res = client.ft().search(q) - assert 2 == len(res.docs) - - def test_sort_by(self, client): - client.ft().create_index( - (TextField("txt"), - NumericField("num", sortable=True)) - ) - client.ft().add_document("doc1", txt="foo bar", num=1) - client.ft().add_document("doc2", txt="foo baz", num=2) - client.ft().add_document("doc3", txt="foo qux", num=3) - - # Test sort - q1 = Query("foo").sort_by("num", asc=True).no_content() - q2 = Query("foo").sort_by("num", asc=False).no_content() - res1, res2 = client.ft().search(q1), client.ft().search(q2) - - assert 3 == res1.total - assert "doc1" == res1.docs[0].id - assert "doc2" == res1.docs[1].id - assert "doc3" == res1.docs[2].id - assert 3 == res2.total - assert "doc1" == res2.docs[2].id - assert "doc2" == res2.docs[1].id - assert "doc3" == res2.docs[0].id - - @skip_ifmodversion_lt("2.0.0", "search") - def test_drop_index(self): - """ - Ensure the index gets dropped by data remains by default - """ - for x in range(20): - for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: - idx = "HaveIt" - index = getClient() - index.hset("index:haveit", mapping={"name": "haveit"}) - idef = IndexDefinition(prefix=["index:"]) - index.ft(idx).create_index((TextField("name"),), - definition=idef) - waitForIndex(index, idx) - index.ft(idx).dropindex(delete_documents=keep_docs[0]) - i = index.hgetall("index:haveit") - assert i == keep_docs[1] - - def test_example(self, client): - # Creating the index definition and schema - client.ft().create_index( - (TextField("title", weight=5.0), - TextField("body")) - ) + client.ft().add_document("doc1", payload="foo baz", txt="foo bar") + client.ft().add_document("doc2", txt="foo bar") - # Indexing a document - client.ft().add_document( - "doc1", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - ) + q = Query("foo bar").with_payloads() + res = client.ft().search(q) + assert 2 == res.total + assert "doc1" == res.docs[0].id + assert "doc2" == res.docs[1].id + assert "foo baz" == res.docs[0].payload + assert res.docs[1].payload is None - # Searching with complex parameters: - q = Query("search engine").verbatim().no_content().paging(0, 5) - - res = client.ft().search(q) - assert res is not None - - def test_auto_complete(self, client): - n = 0 - with open(TITLES_CSV) as f: - cr = csv.reader(f) - - for row in cr: - n += 1 - term, score = row[0], float(row[1]) - assert n == client.ft().sugadd("ac", - Suggestion(term, score=score)) - - assert n == client.ft().suglen("ac") - ret = client.ft().sugget("ac", "bad", with_scores=True) - assert 2 == len(ret) - assert "badger" == ret[0].string - assert isinstance(ret[0].score, float) - assert 1.0 != ret[0].score - assert "badalte rishtey" == ret[1].string - assert isinstance(ret[1].score, float) - assert 1.0 != ret[1].score - - ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - assert 10 == len(ret) - assert 1.0 == ret[0].score - strs = {x.string for x in ret} - - for sug in strs: - assert 1 == client.ft().sugdel("ac", sug) - # make sure a second delete returns 0 - for sug in strs: - assert 0 == client.ft().sugdel("ac", sug) - - # make sure they were actually deleted - ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) - for sug in ret2: - assert sug.string not in strs - - # Test with payload - client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) - client.ft().sugadd("ac", 
Suggestion("pay2", payload="pl2")) - client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) - - sugs = client.ft().sugget( - "ac", - "pay", - with_payloads=True, - with_scores=True - ) - assert 3 == len(sugs) - for sug in sugs: - assert sug.payload - assert sug.payload.startswith("pl") - - def test_no_index(self, client): - client.ft().create_index( - ( - TextField("field"), - TextField("text", no_index=True, sortable=True), - NumericField("numeric", no_index=True, sortable=True), - GeoField("geo", no_index=True, sortable=True), - TagField("tag", no_index=True, sortable=True), - ) - ) - client.ft().add_document( - "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" - ) - client.ft().add_document( - "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" - ) - waitForIndex(client, "idx") +@pytest.mark.redismod +def test_scores(client): + client.ft().create_index((TextField("txt"),)) - res = client.ft().search(Query("@text:aa*")) - assert 0 == res.total + client.ft().add_document("doc1", txt="foo baz") + client.ft().add_document("doc2", txt="foo bar") - res = client.ft().search(Query("@field:aa*")) - assert 2 == res.total + q = Query("foo ~bar").with_scores() + res = client.ft().search(q) + assert 2 == res.total + assert "doc2" == res.docs[0].id + assert 3.0 == res.docs[0].score + assert "doc1" == res.docs[1].id + # todo: enable once new RS version is tagged + # self.assertEqual(0.2, res.docs[1].score) - res = client.ft().search(Query("*").sort_by("text", asc=False)) - assert 2 == res.total - assert "doc2" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("text", asc=True)) - assert "doc1" == res.docs[0].id +@pytest.mark.redismod +def test_replace(client): + client.ft().create_index((TextField("txt"),)) + + client.ft().add_document("doc1", txt="foo bar") + client.ft().add_document("doc2", txt="foo bar") + waitForIndex(client, "idx") + + res = client.ft().search("foo bar") + assert 2 == res.total + client.ft().add_document( + "doc1", + replace=True, + txt="this is a replaced doc" + ) - res = client.ft().search(Query("*").sort_by("numeric", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search("foo bar") + assert 1 == res.total + assert "doc2" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("geo", asc=True)) - assert "doc1" == res.docs[0].id + res = client.ft().search("replaced doc") + assert 1 == res.total + assert "doc1" == res.docs[0].id - res = client.ft().search(Query("*").sort_by("tag", asc=True)) - assert "doc1" == res.docs[0].id - # Ensure exception is raised for non-indexable, non-sortable fields - with pytest.raises(Exception): - TextField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - NumericField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - GeoField("name", no_index=True, sortable=False) - with pytest.raises(Exception): - TagField("name", no_index=True, sortable=False) +@pytest.mark.redismod +def test_stopwords(client): + client.ft().create_index( + (TextField("txt"),), + stopwords=["foo", "bar", "baz"] + ) + client.ft().add_document("doc1", txt="foo bar") + client.ft().add_document("doc2", txt="hello world") + waitForIndex(client, "idx") - def test_partial(self, client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", partial=True) - 
client.ft().add_document("doc2", f3="f3_val", replace=True) - waitForIndex(client, "idx") - - # Search for f3 value. All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 - # values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - def test_no_create(self, client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) - client.ft().add_document("doc1", f1="f1_val", f2="f2_val") - client.ft().add_document("doc2", f1="f1_val", f2="f2_val") - client.ft().add_document("doc1", f3="f3_val", no_create=True) - client.ft().add_document("doc2", f3="f3_val", no_create=True, - partial=True) - waitForIndex(client, "idx") - - # Search for f3 value. All documents should have it - res = client.ft().search("@f3:f3_val") - assert 2 == res.total - - # Only the document updated with PARTIAL should still have f1 and f2 - # values - res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") - assert 1 == res.total - - with pytest.raises(redis.ResponseError): - client.ft().add_document( - "doc3", - f2="f2_val", - f3="f3_val", - no_create=True - ) - - def test_explain(self, client): - client.ft().create_index( - (TextField("f1"), - TextField("f2"), - TextField("f3")) - ) - res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") - assert res - - def test_summarize(self, client): - createIndex(client.ft()) - waitForIndex(client, "idx") - - q = Query("king henry").paging(0, 1) - q.highlight(fields=("play", "txt"), tags=("", "")) - q.summarize("txt") - - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry IV" == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING " - "HENRY, LORD JOHN OF LANCASTER, the EARL of " - "WESTMORELAND, SIR... " # noqa - == doc.txt - ) + q1 = Query("foo bar").no_content() + q2 = Query("foo bar hello world").no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) + assert 0 == res1.total + assert 1 == res2.total - q = Query("king henry").paging(0, 1).summarize().highlight() - doc = sorted(client.ft().search(q).docs)[0] - assert "Henry ... " == doc.play - assert ( - "ACT I SCENE I. London. The palace. Enter KING " - "HENRY, LORD JOHN OF LANCASTER, the EARL of " - "WESTMORELAND, SIR... 
" # noqa - == doc.txt - ) +@pytest.mark.redismod +def test_filters(client): + client.ft().create_index( + (TextField("txt"), + NumericField("num"), + GeoField("loc")) + ) + client.ft().add_document( + "doc1", + txt="foo bar", + num=3.141, + loc="-0.441,51.458" + ) + client.ft().add_document("doc2", txt="foo baz", num=2, loc="-0.1,51.2") + + waitForIndex(client, "idx") + # Test numerical filter + q1 = Query("foo").add_filter(NumericFilter("num", 0, 2)).no_content() + q2 = ( + Query("foo") + .add_filter( + NumericFilter("num", 2, NumericFilter.INF, minExclusive=True)) + .no_content() + ) + res1, res2 = client.ft().search(q1), client.ft().search(q2) - @skip_ifmodversion_lt("2.0.0", "search") - def test_alias(self): - index1 = getClient() - index2 = getClient() + assert 1 == res1.total + assert 1 == res2.total + assert "doc2" == res1.docs[0].id + assert "doc1" == res2.docs[0].id - index1.hset("index1:lonestar", mapping={"name": "lonestar"}) - index2.hset("index2:yogurt", mapping={"name": "yogurt"}) + # Test geo filter + q1 = Query("foo").add_filter( + GeoFilter("loc", -0.44, 51.45, 10)).no_content() + q2 = Query("foo").add_filter( + GeoFilter("loc", -0.44, 51.45, 100)).no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) - if os.environ.get("GITHUB_WORKFLOW", None) is not None: - time.sleep(2) - else: - time.sleep(5) + assert 1 == res1.total + assert 2 == res2.total + assert "doc1" == res1.docs[0].id - def1 = IndexDefinition(prefix=["index1:"], score_field="name") - def2 = IndexDefinition(prefix=["index2:"], score_field="name") + # Sort results, after RDB reload order may change + res = [res2.docs[0].id, res2.docs[1].id] + res.sort() + assert ["doc1", "doc2"] == res - ftindex1 = index1.ft("testAlias") - ftindex2 = index1.ft("testAlias2") - ftindex1.create_index((TextField("name"),), definition=def1) - ftindex2.create_index((TextField("name"),), definition=def2) - # CI is slower - try: - res = ftindex1.search("*").docs[0] - except IndexError: - time.sleep(5) - res = ftindex1.search("*").docs[0] - assert "index1:lonestar" == res.id - - # create alias and check for results - ftindex1.aliasadd("spaceballs") - alias_client = getClient().ft("spaceballs") - res = alias_client.search("*").docs[0] - assert "index1:lonestar" == res.id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - ftindex2.aliasadd("spaceballs") - - # update alias and ensure new results - ftindex2.aliasupdate("spaceballs") - alias_client2 = getClient().ft("spaceballs") - res = alias_client2.search("*").docs[0] - assert "index2:yogurt" == res.id - - ftindex2.aliasdel("spaceballs") - with pytest.raises(Exception): - alias_client2.search("*").docs[0] - - def test_alias_basic(self): - # Creating a client with one index - getClient().flushdb() - index1 = getClient().ft("testAlias") - - index1.create_index((TextField("txt"),)) - index1.add_document("doc1", txt="text goes here") - - index2 = getClient().ft("testAlias2") - index2.create_index((TextField("txt"),)) - index2.add_document("doc2", txt="text goes here") - - # add the actual alias and check - index1.aliasadd("myalias") - alias_client = getClient().ft("myalias") - res = sorted(alias_client.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # Throw an exception when trying to add an alias that already exists - with pytest.raises(Exception): - index2.aliasadd("myalias") - - # update the alias and ensure we get doc2 - index2.aliasupdate("myalias") - alias_client2 = 
getClient().ft("myalias") - res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) - assert "doc1" == res[0].id - - # delete the alias and expect an error if we try to query again - index2.aliasdel("myalias") - with pytest.raises(Exception): - _ = alias_client2.search("*").docs[0] - - def test_tags(self, client): - client.ft().create_index((TextField("txt"), TagField("tags"))) - tags = "foo,foo bar,hello;world" - tags2 = "soba,ramen" - - client.ft().add_document("doc1", txt="fooz barz", tags=tags) - client.ft().add_document("doc2", txt="noodles", tags=tags2) - waitForIndex(client, "idx") - - q = Query("@tags:{foo}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{foo\\ bar}") - res = client.ft().search(q) - assert 1 == res.total - - q = Query("@tags:{hello\\;world}") - res = client.ft().search(q) - assert 1 == res.total - - q2 = client.ft().tagvals("tags") - assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() - - def test_textfield_sortable_nostem(self, client): - # Creating the index definition with sortable and no_stem - client.ft().create_index( - (TextField("txt", sortable=True, no_stem=True),)) - - # Now get the index info to confirm its contents - response = client.ft().info() - assert "SORTABLE" in response["attributes"][0] - assert "NOSTEM" in response["attributes"][0] - - def test_alter_schema_add(self, client): - # Creating the index definition and schema - client.ft().create_index(TextField("title")) - - # Using alter to add a field - client.ft().alter_schema_add(TextField("body")) - - # Indexing a document - client.ft().add_document( - "doc1", title="MyTitle", body="Some content only in the body" - ) +@pytest.mark.redismod +def test_payloads_with_no_content(client): + client.ft().create_index((TextField("txt"),)) + client.ft().add_document("doc1", payload="foo baz", txt="foo bar") + client.ft().add_document("doc2", payload="foo baz2", txt="foo bar") - # Searching with parameter only in the body (the added field) - q = Query("only in the body") + q = Query("foo bar").with_payloads().no_content() + res = client.ft().search(q) + assert 2 == len(res.docs) - # Ensure we find the result searching on the added body field - res = client.ft().search(q) - assert 1 == res.total - def test_spell_check(self, client): - client.ft().create_index((TextField("f1"), TextField("f2"))) +@pytest.mark.redismod +def test_sort_by(client): + client.ft().create_index( + (TextField("txt"), + NumericField("num", sortable=True)) + ) + client.ft().add_document("doc1", txt="foo bar", num=1) + client.ft().add_document("doc2", txt="foo baz", num=2) + client.ft().add_document("doc3", txt="foo qux", num=3) - client.ft().add_document( - "doc1", - f1="some valid content", - f2="this is sample text" - ) - client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") - waitForIndex(client, "idx") - - # test spellcheck - res = client.ft().spellcheck("impornant") - assert "important" == res["impornant"][0]["suggestion"] - - res = client.ft().spellcheck("contnt") - assert "content" == res["contnt"][0]["suggestion"] - - # test spellcheck with Levenshtein distance - res = client.ft().spellcheck("vlis") - assert res == {} - res = client.ft().spellcheck("vlis", distance=2) - assert "valid" == res["vlis"][0]["suggestion"] - - # test spellcheck include - client.ft().dict_add("dict", "lore", "lorem", "lorm") - res = client.ft().spellcheck("lorm", include="dict") - assert 
len(res["lorm"]) == 3 - assert ( - res["lorm"][0]["suggestion"], - res["lorm"][1]["suggestion"], - res["lorm"][2]["suggestion"], - ) == ("lorem", "lore", "lorm") - assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ( - "0.5", "0") - - # test spellcheck exclude - res = client.ft().spellcheck("lorm", exclude="dict") - assert res == {} - - def test_dict_operations(self, client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - # Add three items - res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") - assert 3 == res - - # Remove one item - res = client.ft().dict_del("custom_dict", "item2") - assert 1 == res - - # Dump dict and inspect content - res = client.ft().dict_dump("custom_dict") - assert ["item1", "item3"] == res - - # Remove rest of the items before reload - client.ft().dict_del("custom_dict", *res) - - def test_phonetic_matcher(self, client): - client.ft().create_index((TextField("name"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") - - res = client.ft().search(Query("Jon")) - assert 1 == len(res.docs) - assert "Jon" == res.docs[0].name - - # Drop and create index with phonetic matcher - client.flushdb() - - client.ft().create_index( - (TextField("name", phonetic_matcher="dm:en"),)) - client.ft().add_document("doc1", name="Jon") - client.ft().add_document("doc2", name="John") - - res = client.ft().search(Query("Jon")) - assert 2 == len(res.docs) - assert ["John", "Jon"] == sorted([d.name for d in res.docs]) - - def test_scorer(self, client): - client.ft().create_index((TextField("description"),)) + # Test sort + q1 = Query("foo").sort_by("num", asc=True).no_content() + q2 = Query("foo").sort_by("num", asc=False).no_content() + res1, res2 = client.ft().search(q1), client.ft().search(q2) - client.ft().add_document( - "doc1", description="The quick brown fox jumps over the lazy dog" - ) - client.ft().add_document( - "doc2", - description="Quick alice was beginning to get very tired of " - "sitting by her quick sister on the bank, and of " - "having nothing to do.", - # noqa - ) + assert 3 == res1.total + assert "doc1" == res1.docs[0].id + assert "doc2" == res1.docs[1].id + assert "doc3" == res1.docs[2].id + assert 3 == res2.total + assert "doc1" == res2.docs[2].id + assert "doc2" == res2.docs[1].id + assert "doc3" == res2.docs[0].id - # default scorer is TFIDF - res = client.ft().search(Query("quick").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("TFIDF").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search( - Query("quick").scorer("TFIDF.DOCNORM").with_scores()) - assert 0.1111111111111111 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("BM25").with_scores()) - assert 0.17699114465425977 == res.docs[0].score - res = client.ft().search(Query("quick").scorer("DISMAX").with_scores()) - assert 2.0 == res.docs[0].score - res = client.ft().search( - Query("quick").scorer("DOCSCORE").with_scores()) - assert 1.0 == res.docs[0].score - res = client.ft().search( - Query("quick").scorer("HAMMING").with_scores()) - assert 0.0 == res.docs[0].score - - def test_get(self, client): - client.ft().create_index((TextField("f1"), TextField("f2"))) - - assert [None] == client.ft().get("doc1") - assert [None, None] == client.ft().get("doc2", "doc1") - client.ft().add_document( - "doc1", f1="some valid content dd1", f2="this is sample text ff1" - ) - client.ft().add_document( - "doc2", f1="some valid content dd2", 
f2="this is sample text ff2" - ) +@pytest.mark.redismod +@skip_ifmodversion_lt("2.0.0", "search") +def test_drop_index(): + """ + Ensure the index gets dropped by data remains by default + """ + for x in range(20): + for keep_docs in [[True, {}], [False, {"name": "haveit"}]]: + idx = "HaveIt" + index = getClient() + index.hset("index:haveit", mapping={"name": "haveit"}) + idef = IndexDefinition(prefix=["index:"]) + index.ft(idx).create_index((TextField("name"),), definition=idef) + waitForIndex(index, idx) + index.ft(idx).dropindex(delete_documents=keep_docs[0]) + i = index.hgetall("index:haveit") + assert i == keep_docs[1] - assert [ - ["f1", "some valid content dd2", "f2", - "this is sample text ff2"] - ] == client.ft().get("doc2") - assert [ - ["f1", "some valid content dd1", "f2", - "this is sample text ff1"], - ["f1", "some valid content dd2", "f2", - "this is sample text ff2"], - ] == client.ft().get("doc1", "doc2") - - @skip_ifmodversion_lt("2.2.0", "search") - def test_config(self, client): - assert client.ft().config_set("TIMEOUT", "100") - with pytest.raises(redis.ResponseError): - client.ft().config_set("TIMEOUT", "null") - res = client.ft().config_get("*") - assert "100" == res["TIMEOUT"] - res = client.ft().config_get("TIMEOUT") - assert "100" == res["TIMEOUT"] - - def test_aggregations(self, client): - # Creating the index definition and schema - client.ft().create_index( - ( - NumericField("random_num"), - TextField("title"), - TextField("body"), - TextField("parent"), - ) - ) - # Indexing a document - client.ft().add_document( - "search", - title="RediSearch", - body="Redisearch impements a search engine on top of redis", - parent="redis", - random_num=10, - ) - client.ft().add_document( - "ai", - title="RedisAI", - body="RedisAI executes Deep Learning/Machine Learning models and" - " managing their data.", - # noqa - parent="redis", - random_num=3, - ) - client.ft().add_document( - "json", - title="RedisJson", - body="RedisJSON implements ECMA-404 The JSON Data Interchange " - "Standard as a native data type.", - # noqa - parent="redis", - random_num=8, - ) +@pytest.mark.redismod +def test_example(client): + # Creating the index definition and schema + client.ft().create_index( + (TextField("title", weight=5.0), + TextField("body")) + ) - req = aggregations.AggregateRequest("redis").group_by( - "@parent", - reducers.count(), - reducers.count_distinct("@title"), - reducers.count_distinctish("@title"), - reducers.sum("@random_num"), - reducers.min("@random_num"), - reducers.max("@random_num"), - reducers.avg("@random_num"), - reducers.stddev("random_num"), - reducers.quantile("@random_num", 0.5), - reducers.tolist("@title"), - reducers.first_value("@title"), - reducers.random_sample("@title", 2), - ) + # Indexing a document + client.ft().add_document( + "doc1", + title="RediSearch", + body="Redisearch impements a search engine on top of redis", + ) - res = client.ft().aggregate(req) - - res = res.rows[0] - assert len(res) == 26 - assert "redis" == res[1] - assert "3" == res[3] - assert "3" == res[5] - assert "3" == res[7] - assert "21" == res[9] - assert "3" == res[11] - assert "10" == res[13] - assert "7" == res[15] - assert "3.60555127546" == res[17] - assert "10" == res[19] - assert ["RediSearch", "RedisAI", "RedisJson"] == res[21] - assert "RediSearch" == res[23] - assert 2 == len(res[25]) - - @skip_ifmodversion_lt("2.0.0", "search") - def test_index_definition(self, client): - """ - Create definition and test its args - """ - with pytest.raises(RuntimeError): - 
IndexDefinition(prefix=["hset:", "henry"], index_type="json") - - definition = IndexDefinition( - prefix=["hset:", "henry"], - filter="@f1==32", - language="English", - language_field="play", - score_field="chapter", - score=0.5, - payload_field="txt", - index_type=IndexType.JSON, - ) + # Searching with complex parameters: + q = Query("search engine").verbatim().no_content().paging(0, 5) - assert [ - "ON", - "JSON", - "PREFIX", - 2, - "hset:", - "henry", - "FILTER", - "@f1==32", - "LANGUAGE_FIELD", - "play", - "LANGUAGE", - "English", - "SCORE_FIELD", - "chapter", - "SCORE", - 0.5, - "PAYLOAD_FIELD", - "txt", - ] == definition.args - - createIndex(client.ft(), num_docs=500, definition=definition) - - @skip_ifmodversion_lt("2.0.0", "search") - def test_create_client_definition(self, client): - """ - Create definition with no index type provided, - and use hset to test the client definition (the default is HASH). - """ - definition = IndexDefinition(prefix=["hset:", "henry"]) - createIndex(client.ft(), num_docs=500, definition=definition) - - info = client.ft().info() - assert 494 == int(info["num_docs"]) - - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) - - @skip_ifmodversion_lt("2.0.0", "search") - def test_create_client_definition_hash(self, client): - """ - Create definition with IndexType.HASH as index type (ON HASH), - and use hset to test the client definition. - """ - definition = IndexDefinition( - prefix=["hset:", "henry"], - index_type=IndexType.HASH - ) - createIndex(client.ft(), num_docs=500, definition=definition) - - info = client.ft().info() - assert 494 == int(info["num_docs"]) - - client.ft().client.hset("hset:1", "f1", "v1") - info = client.ft().info() - assert 495 == int(info["num_docs"]) - - @skip_ifmodversion_lt("2.2.0", "search") - def test_create_client_definition_json(self, client): - """ - Create definition with IndexType.JSON as index type (ON JSON), - and use json client to test it. 
- """ - definition = IndexDefinition(prefix=["king:"], - index_type=IndexType.JSON) - client.ft().create_index((TextField("$.name"),), definition=definition) - - client.json().set("king:1", Path.rootPath(), {"name": "henry"}) - client.json().set("king:2", Path.rootPath(), {"name": "james"}) - - res = client.ft().search("henry") - assert res.docs[0].id == "king:1" - assert res.docs[0].payload is None - assert res.docs[0].json == '{"name":"henry"}' - assert res.total == 1 - - @skip_ifmodversion_lt("2.2.0", "search") - def test_fields_as_name(self, client): - # create index - SCHEMA = ( - TextField("$.name", sortable=True, as_name="name"), - NumericField("$.age", as_name="just_a_number"), - ) - definition = IndexDefinition(index_type=IndexType.JSON) - client.ft().create_index(SCHEMA, definition=definition) - - # insert json data - res = client.json().set( - "doc:1", - Path.rootPath(), - {"name": "Jon", "age": 25} - ) - assert res - - total = client.ft().search( - Query("Jon").return_fields("name", "just_a_number")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "Jon" == total[0].name - assert "25" == total[0].just_a_number - - @skip_ifmodversion_lt("2.2.0", "search") - def test_search_return_fields(self, client): - res = client.json().set( - "doc:1", - Path.rootPath(), - {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, - ) - assert res + res = client.ft().search(q) + assert res is not None - # create index on - definition = IndexDefinition(index_type=IndexType.JSON) - SCHEMA = ( - TextField("$.t"), - NumericField("$.flt"), - ) - client.ft().create_index(SCHEMA, definition=definition) - waitForIndex(client, "idx") - - total = client.ft().search( - Query("*").return_field("$.t", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "riceratops" == total[0].txt - - total = client.ft().search( - Query("*").return_field("$.t2", as_field="txt")).docs - assert 1 == len(total) - assert "doc:1" == total[0].id - assert "telmatosaurus" == total[0].txt - - def test_synupdate(self, client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - ( - TextField("title"), - TextField("body"), - ), - definition=definition, + +@pytest.mark.redismod +def test_auto_complete(client): + n = 0 + with open(TITLES_CSV) as f: + cr = csv.reader(f) + + for row in cr: + n += 1 + term, score = row[0], float(row[1]) + assert n == client.ft().sugadd("ac", Suggestion(term, score=score)) + + assert n == client.ft().suglen("ac") + ret = client.ft().sugget("ac", "bad", with_scores=True) + assert 2 == len(ret) + assert "badger" == ret[0].string + assert isinstance(ret[0].score, float) + assert 1.0 != ret[0].score + assert "badalte rishtey" == ret[1].string + assert isinstance(ret[1].score, float) + assert 1.0 != ret[1].score + + ret = client.ft().sugget("ac", "bad", fuzzy=True, num=10) + assert 10 == len(ret) + assert 1.0 == ret[0].score + strs = {x.string for x in ret} + + for sug in strs: + assert 1 == client.ft().sugdel("ac", sug) + # make sure a second delete returns 0 + for sug in strs: + assert 0 == client.ft().sugdel("ac", sug) + + # make sure they were actually deleted + ret2 = client.ft().sugget("ac", "bad", fuzzy=True, num=10) + for sug in ret2: + assert sug.string not in strs + + # Test with payload + client.ft().sugadd("ac", Suggestion("pay1", payload="pl1")) + client.ft().sugadd("ac", Suggestion("pay2", payload="pl2")) + client.ft().sugadd("ac", Suggestion("pay3", payload="pl3")) + + sugs = 
client.ft().sugget( + "ac", + "pay", + with_payloads=True, + with_scores=True + ) + assert 3 == len(sugs) + for sug in sugs: + assert sug.payload + assert sug.payload.startswith("pl") + + +@pytest.mark.redismod +def test_no_index(client): + client.ft().create_index( + ( + TextField("field"), + TextField("text", no_index=True, sortable=True), + NumericField("numeric", no_index=True, sortable=True), + GeoField("geo", no_index=True, sortable=True), + TagField("tag", no_index=True, sortable=True), ) + ) - client.ft().synupdate("id1", True, "boy", "child", "offspring") - client.ft().add_document( - "doc1", - title="he is a baby", - body="this is a test") + client.ft().add_document( + "doc1", field="aaa", text="1", numeric="1", geo="1,1", tag="1" + ) + client.ft().add_document( + "doc2", field="aab", text="2", numeric="2", geo="2,2", tag="2" + ) + waitForIndex(client, "idx") + + res = client.ft().search(Query("@text:aa*")) + assert 0 == res.total + + res = client.ft().search(Query("@field:aa*")) + assert 2 == res.total + + res = client.ft().search(Query("*").sort_by("text", asc=False)) + assert 2 == res.total + assert "doc2" == res.docs[0].id + + res = client.ft().search(Query("*").sort_by("text", asc=True)) + assert "doc1" == res.docs[0].id + + res = client.ft().search(Query("*").sort_by("numeric", asc=True)) + assert "doc1" == res.docs[0].id + + res = client.ft().search(Query("*").sort_by("geo", asc=True)) + assert "doc1" == res.docs[0].id + + res = client.ft().search(Query("*").sort_by("tag", asc=True)) + assert "doc1" == res.docs[0].id + + # Ensure exception is raised for non-indexable, non-sortable fields + with pytest.raises(Exception): + TextField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + NumericField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + GeoField("name", no_index=True, sortable=False) + with pytest.raises(Exception): + TagField("name", no_index=True, sortable=False) + + +@pytest.mark.redismod +def test_partial(client): + client.ft().create_index( + (TextField("f1"), + TextField("f2"), + TextField("f3")) + ) + client.ft().add_document("doc1", f1="f1_val", f2="f2_val") + client.ft().add_document("doc2", f1="f1_val", f2="f2_val") + client.ft().add_document("doc1", f3="f3_val", partial=True) + client.ft().add_document("doc2", f3="f3_val", replace=True) + waitForIndex(client, "idx") + + # Search for f3 value. All documents should have it + res = client.ft().search("@f3:f3_val") + assert 2 == res.total - client.ft().synupdate("id1", True, "baby") + # Only the document updated with PARTIAL should still have f1 and f2 values + res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") + assert 1 == res.total + + +@pytest.mark.redismod +def test_no_create(client): + client.ft().create_index( + (TextField("f1"), + TextField("f2"), + TextField("f3")) + ) + client.ft().add_document("doc1", f1="f1_val", f2="f2_val") + client.ft().add_document("doc2", f1="f1_val", f2="f2_val") + client.ft().add_document("doc1", f3="f3_val", no_create=True) + client.ft().add_document("doc2", f3="f3_val", no_create=True, partial=True) + waitForIndex(client, "idx") + + # Search for f3 value. 
All documents should have it + res = client.ft().search("@f3:f3_val") + assert 2 == res.total + + # Only the document updated with PARTIAL should still have f1 and f2 values + res = client.ft().search("@f3:f3_val @f2:f2_val @f1:f1_val") + assert 1 == res.total + + with pytest.raises(redis.ResponseError): client.ft().add_document( - "doc2", - title="he is another baby", - body="another test" + "doc3", + f2="f2_val", + f3="f3_val", + no_create=True ) - res = client.ft().search(Query("child").expander("SYNONYM")) - assert res.docs[0].id == "doc2" - assert res.docs[0].title == "he is another baby" - assert res.docs[0].body == "another test" - - def test_syndump(self, client): - definition = IndexDefinition(index_type=IndexType.HASH) - client.ft().create_index( - ( - TextField("title"), - TextField("body"), - ), - definition=definition, + +@pytest.mark.redismod +def test_explain(client): + client.ft().create_index( + (TextField("f1"), + TextField("f2"), + TextField("f3")) + ) + res = client.ft().explain("@f3:f3_val @f2:f2_val @f1:f1_val") + assert res + + +@pytest.mark.redismod +def test_summarize(client): + createIndex(client.ft()) + waitForIndex(client, "idx") + + q = Query("king henry").paging(0, 1) + q.highlight(fields=("play", "txt"), tags=("", "")) + q.summarize("txt") + + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry IV" == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... " # noqa + == doc.txt + ) + + q = Query("king henry").paging(0, 1).summarize().highlight() + + doc = sorted(client.ft().search(q).docs)[0] + assert "Henry ... " == doc.play + assert ( + "ACT I SCENE I. London. The palace. Enter KING HENRY, LORD JOHN OF LANCASTER, the EARL of WESTMORELAND, SIR... 
" # noqa + == doc.txt + ) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.0.0", "search") +def test_alias(): + index1 = getClient() + index2 = getClient() + + index1.hset("index1:lonestar", mapping={"name": "lonestar"}) + index2.hset("index2:yogurt", mapping={"name": "yogurt"}) + + if os.environ.get("GITHUB_WORKFLOW", None) is not None: + time.sleep(2) + else: + time.sleep(5) + + def1 = IndexDefinition(prefix=["index1:"], score_field="name") + def2 = IndexDefinition(prefix=["index2:"], score_field="name") + + ftindex1 = index1.ft("testAlias") + ftindex2 = index1.ft("testAlias2") + ftindex1.create_index((TextField("name"),), definition=def1) + ftindex2.create_index((TextField("name"),), definition=def2) + + # CI is slower + try: + res = ftindex1.search("*").docs[0] + except IndexError: + time.sleep(5) + res = ftindex1.search("*").docs[0] + assert "index1:lonestar" == res.id + + # create alias and check for results + ftindex1.aliasadd("spaceballs") + alias_client = getClient().ft("spaceballs") + res = alias_client.search("*").docs[0] + assert "index1:lonestar" == res.id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + ftindex2.aliasadd("spaceballs") + + # update alias and ensure new results + ftindex2.aliasupdate("spaceballs") + alias_client2 = getClient().ft("spaceballs") + res = alias_client2.search("*").docs[0] + assert "index2:yogurt" == res.id + + ftindex2.aliasdel("spaceballs") + with pytest.raises(Exception): + alias_client2.search("*").docs[0] + + +@pytest.mark.redismod +def test_alias_basic(): + # Creating a client with one index + getClient().flushdb() + index1 = getClient().ft("testAlias") + + index1.create_index((TextField("txt"),)) + index1.add_document("doc1", txt="text goes here") + + index2 = getClient().ft("testAlias2") + index2.create_index((TextField("txt"),)) + index2.add_document("doc2", txt="text goes here") + + # add the actual alias and check + index1.aliasadd("myalias") + alias_client = getClient().ft("myalias") + res = sorted(alias_client.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # Throw an exception when trying to add an alias that already exists + with pytest.raises(Exception): + index2.aliasadd("myalias") + + # update the alias and ensure we get doc2 + index2.aliasupdate("myalias") + alias_client2 = getClient().ft("myalias") + res = sorted(alias_client2.search("*").docs, key=lambda x: x.id) + assert "doc1" == res[0].id + + # delete the alias and expect an error if we try to query again + index2.aliasdel("myalias") + with pytest.raises(Exception): + _ = alias_client2.search("*").docs[0] + + +@pytest.mark.redismod +def test_tags(client): + client.ft().create_index((TextField("txt"), TagField("tags"))) + tags = "foo,foo bar,hello;world" + tags2 = "soba,ramen" + + client.ft().add_document("doc1", txt="fooz barz", tags=tags) + client.ft().add_document("doc2", txt="noodles", tags=tags2) + waitForIndex(client, "idx") + + q = Query("@tags:{foo}") + res = client.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo bar}") + res = client.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{foo\\ bar}") + res = client.ft().search(q) + assert 1 == res.total + + q = Query("@tags:{hello\\;world}") + res = client.ft().search(q) + assert 1 == res.total + + q2 = client.ft().tagvals("tags") + assert (tags.split(",") + tags2.split(",")).sort() == q2.sort() + + +@pytest.mark.redismod +def test_textfield_sortable_nostem(client): + # Creating the index definition with 
sortable and no_stem + client.ft().create_index((TextField("txt", sortable=True, no_stem=True),)) + + # Now get the index info to confirm its contents + response = client.ft().info() + assert "SORTABLE" in response["attributes"][0] + assert "NOSTEM" in response["attributes"][0] + + +@pytest.mark.redismod +def test_alter_schema_add(client): + # Creating the index definition and schema + client.ft().create_index(TextField("title")) + + # Using alter to add a field + client.ft().alter_schema_add(TextField("body")) + + # Indexing a document + client.ft().add_document( + "doc1", title="MyTitle", body="Some content only in the body" + ) + + # Searching with parameter only in the body (the added field) + q = Query("only in the body") + + # Ensure we find the result searching on the added body field + res = client.ft().search(q) + assert 1 == res.total + + +@pytest.mark.redismod +def test_spell_check(client): + client.ft().create_index((TextField("f1"), TextField("f2"))) + + client.ft().add_document( + "doc1", + f1="some valid content", + f2="this is sample text" + ) + client.ft().add_document("doc2", f1="very important", f2="lorem ipsum") + waitForIndex(client, "idx") + + # test spellcheck + res = client.ft().spellcheck("impornant") + assert "important" == res["impornant"][0]["suggestion"] + + res = client.ft().spellcheck("contnt") + assert "content" == res["contnt"][0]["suggestion"] + + # test spellcheck with Levenshtein distance + res = client.ft().spellcheck("vlis") + assert res == {} + res = client.ft().spellcheck("vlis", distance=2) + assert "valid" == res["vlis"][0]["suggestion"] + + # test spellcheck include + client.ft().dict_add("dict", "lore", "lorem", "lorm") + res = client.ft().spellcheck("lorm", include="dict") + assert len(res["lorm"]) == 3 + assert ( + res["lorm"][0]["suggestion"], + res["lorm"][1]["suggestion"], + res["lorm"][2]["suggestion"], + ) == ("lorem", "lore", "lorm") + assert (res["lorm"][0]["score"], res["lorm"][1]["score"]) == ("0.5", "0") + + # test spellcheck exclude + res = client.ft().spellcheck("lorm", exclude="dict") + assert res == {} + + +@pytest.mark.redismod +def test_dict_operations(client): + client.ft().create_index((TextField("f1"), TextField("f2"))) + # Add three items + res = client.ft().dict_add("custom_dict", "item1", "item2", "item3") + assert 3 == res + + # Remove one item + res = client.ft().dict_del("custom_dict", "item2") + assert 1 == res + + # Dump dict and inspect content + res = client.ft().dict_dump("custom_dict") + assert ["item1", "item3"] == res + + # Remove rest of the items before reload + client.ft().dict_del("custom_dict", *res) + + +@pytest.mark.redismod +def test_phonetic_matcher(client): + client.ft().create_index((TextField("name"),)) + client.ft().add_document("doc1", name="Jon") + client.ft().add_document("doc2", name="John") + + res = client.ft().search(Query("Jon")) + assert 1 == len(res.docs) + assert "Jon" == res.docs[0].name + + # Drop and create index with phonetic matcher + client.flushdb() + + client.ft().create_index((TextField("name", phonetic_matcher="dm:en"),)) + client.ft().add_document("doc1", name="Jon") + client.ft().add_document("doc2", name="John") + + res = client.ft().search(Query("Jon")) + assert 2 == len(res.docs) + assert ["John", "Jon"] == sorted([d.name for d in res.docs]) + + +@pytest.mark.redismod +def test_scorer(client): + client.ft().create_index((TextField("description"),)) + + client.ft().add_document( + "doc1", description="The quick brown fox jumps over the lazy dog" + ) + 
client.ft().add_document(
+        "doc2",
+        description="Quick alice was beginning to get very tired of sitting by her quick sister on the bank, and of having nothing to do.",  # noqa
+    )
+
+    # default scorer is TFIDF
+    res = client.ft().search(Query("quick").with_scores())
+    assert 1.0 == res.docs[0].score
+    res = client.ft().search(Query("quick").scorer("TFIDF").with_scores())
+    assert 1.0 == res.docs[0].score
+    res = client.ft().search(
+        Query("quick").scorer("TFIDF.DOCNORM").with_scores())
+    assert 0.1111111111111111 == res.docs[0].score
+    res = client.ft().search(Query("quick").scorer("BM25").with_scores())
+    assert 0.17699114465425977 == res.docs[0].score
+    res = client.ft().search(Query("quick").scorer("DISMAX").with_scores())
+    assert 2.0 == res.docs[0].score
+    res = client.ft().search(Query("quick").scorer("DOCSCORE").with_scores())
+    assert 1.0 == res.docs[0].score
+    res = client.ft().search(Query("quick").scorer("HAMMING").with_scores())
+    assert 0.0 == res.docs[0].score
+
+
+@pytest.mark.redismod
+def test_get(client):
+    client.ft().create_index((TextField("f1"), TextField("f2")))
+
+    assert [None] == client.ft().get("doc1")
+    assert [None, None] == client.ft().get("doc2", "doc1")
+
+    client.ft().add_document(
+        "doc1", f1="some valid content dd1", f2="this is sample text ff1"
+    )
+    client.ft().add_document(
+        "doc2", f1="some valid content dd2", f2="this is sample text ff2"
+    )
+
+    assert [
+        ["f1", "some valid content dd2", "f2", "this is sample text ff2"]
+    ] == client.ft().get("doc2")
+    assert [
+        ["f1", "some valid content dd1", "f2", "this is sample text ff1"],
+        ["f1", "some valid content dd2", "f2", "this is sample text ff2"],
+    ] == client.ft().get("doc1", "doc2")
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("2.2.0", "search")
+def test_config(client):
+    assert client.ft().config_set("TIMEOUT", "100")
+    with pytest.raises(redis.ResponseError):
+        client.ft().config_set("TIMEOUT", "null")
+    res = client.ft().config_get("*")
+    assert "100" == res["TIMEOUT"]
+    res = client.ft().config_get("TIMEOUT")
+    assert "100" == res["TIMEOUT"]
+
+
+@pytest.mark.redismod
+def test_aggregations(client):
+    # Creating the index definition and schema
+    client.ft().create_index(
+        (
+            NumericField("random_num"),
+            TextField("title"),
+            TextField("body"),
+            TextField("parent"),
+        )
+    )
+
+    # Indexing a document
+    client.ft().add_document(
+        "search",
+        title="RediSearch",
+        body="Redisearch implements a search engine on top of redis",
+        parent="redis",
+        random_num=10,
+    )
+    client.ft().add_document(
+        "ai",
+        title="RedisAI",
+        body="RedisAI executes Deep Learning/Machine Learning models and manages their data.",  # noqa
+        parent="redis",
+        random_num=3,
+    )
+    client.ft().add_document(
+        "json",
+        title="RedisJson",
+        body="RedisJSON implements ECMA-404 The JSON Data Interchange Standard as a native data type.",  # noqa
+        parent="redis",
+        random_num=8,
+    )
+
+    req = aggregations.AggregateRequest("redis").group_by(
+        "@parent",
+        reducers.count(),
+        reducers.count_distinct("@title"),
+        reducers.count_distinctish("@title"),
+        reducers.sum("@random_num"),
+        reducers.min("@random_num"),
+        reducers.max("@random_num"),
+        reducers.avg("@random_num"),
+        reducers.stddev("random_num"),
+        reducers.quantile("@random_num", 0.5),
+        reducers.tolist("@title"),
+        reducers.first_value("@title"),
+        reducers.random_sample("@title", 2),
+    )
+
+    res = client.ft().aggregate(req)
+
+    res = res.rows[0]
+    assert len(res) == 26
+    assert "redis" == res[1]
+    assert "3" == res[3]
+    assert "3" == res[5]
+    assert "3" ==
res[7] + assert "21" == res[9] + assert "3" == res[11] + assert "10" == res[13] + assert "7" == res[15] + assert "3.60555127546" == res[17] + assert "10" == res[19] + assert ["RediSearch", "RedisAI", "RedisJson"] == res[21] + assert "RediSearch" == res[23] + assert 2 == len(res[25]) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.0.0", "search") +def test_index_definition(client): + """ + Create definition and test its args + """ + with pytest.raises(RuntimeError): + IndexDefinition(prefix=["hset:", "henry"], index_type="json") + + definition = IndexDefinition( + prefix=["hset:", "henry"], + filter="@f1==32", + language="English", + language_field="play", + score_field="chapter", + score=0.5, + payload_field="txt", + index_type=IndexType.JSON, + ) + + assert [ + "ON", + "JSON", + "PREFIX", + 2, + "hset:", + "henry", + "FILTER", + "@f1==32", + "LANGUAGE_FIELD", + "play", + "LANGUAGE", + "English", + "SCORE_FIELD", + "chapter", + "SCORE", + 0.5, + "PAYLOAD_FIELD", + "txt", + ] == definition.args + + createIndex(client.ft(), num_docs=500, definition=definition) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.0.0", "search") +def test_create_client_definition(client): + """ + Create definition with no index type provided, + and use hset to test the client definition (the default is HASH). + """ + definition = IndexDefinition(prefix=["hset:", "henry"]) + createIndex(client.ft(), num_docs=500, definition=definition) + + info = client.ft().info() + assert 494 == int(info["num_docs"]) + + client.ft().client.hset("hset:1", "f1", "v1") + info = client.ft().info() + assert 495 == int(info["num_docs"]) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.0.0", "search") +def test_create_client_definition_hash(client): + """ + Create definition with IndexType.HASH as index type (ON HASH), + and use hset to test the client definition. + """ + definition = IndexDefinition( + prefix=["hset:", "henry"], + index_type=IndexType.HASH + ) + createIndex(client.ft(), num_docs=500, definition=definition) + + info = client.ft().info() + assert 494 == int(info["num_docs"]) + + client.ft().client.hset("hset:1", "f1", "v1") + info = client.ft().info() + assert 495 == int(info["num_docs"]) + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.2.0", "search") +def test_create_client_definition_json(client): + """ + Create definition with IndexType.JSON as index type (ON JSON), + and use json client to test it. 
+ """ + definition = IndexDefinition(prefix=["king:"], index_type=IndexType.JSON) + client.ft().create_index((TextField("$.name"),), definition=definition) + + client.json().set("king:1", Path.rootPath(), {"name": "henry"}) + client.json().set("king:2", Path.rootPath(), {"name": "james"}) + + res = client.ft().search("henry") + assert res.docs[0].id == "king:1" + assert res.docs[0].payload is None + assert res.docs[0].json == '{"name":"henry"}' + assert res.total == 1 + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.2.0", "search") +def test_fields_as_name(client): + # create index + SCHEMA = ( + TextField("$.name", sortable=True, as_name="name"), + NumericField("$.age", as_name="just_a_number"), + ) + definition = IndexDefinition(index_type=IndexType.JSON) + client.ft().create_index(SCHEMA, definition=definition) + + # insert json data + res = client.json().set( + "doc:1", + Path.rootPath(), + {"name": "Jon", "age": 25} + ) + assert res + + total = client.ft().search( + Query("Jon").return_fields("name", "just_a_number")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "Jon" == total[0].name + assert "25" == total[0].just_a_number + + +@pytest.mark.redismod +@skip_ifmodversion_lt("2.2.0", "search") +def test_search_return_fields(client): + res = client.json().set( + "doc:1", + Path.rootPath(), + {"t": "riceratops", "t2": "telmatosaurus", "n": 9072, "flt": 97.2}, + ) + assert res + + # create index on + definition = IndexDefinition(index_type=IndexType.JSON) + SCHEMA = ( + TextField("$.t"), + NumericField("$.flt"), + ) + client.ft().create_index(SCHEMA, definition=definition) + waitForIndex(client, "idx") + + total = client.ft().search( + Query("*").return_field("$.t", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "riceratops" == total[0].txt + + total = client.ft().search( + Query("*").return_field("$.t2", as_field="txt")).docs + assert 1 == len(total) + assert "doc:1" == total[0].id + assert "telmatosaurus" == total[0].txt + + +@pytest.mark.redismod +def test_synupdate(client): + definition = IndexDefinition(index_type=IndexType.HASH) + client.ft().create_index( + ( + TextField("title"), + TextField("body"), + ), + definition=definition, + ) + + client.ft().synupdate("id1", True, "boy", "child", "offspring") + client.ft().add_document( + "doc1", + title="he is a baby", + body="this is a test") + + client.ft().synupdate("id1", True, "baby") + client.ft().add_document( + "doc2", + title="he is another baby", + body="another test" + ) + + res = client.ft().search(Query("child").expander("SYNONYM")) + assert res.docs[0].id == "doc2" + assert res.docs[0].title == "he is another baby" + assert res.docs[0].body == "another test" + + +@pytest.mark.redismod +def test_syndump(client): + definition = IndexDefinition(index_type=IndexType.HASH) + client.ft().create_index( + ( + TextField("title"), + TextField("body"), + ), + definition=definition, + ) - client.ft().synupdate("id1", False, "boy", "child", "offspring") - client.ft().synupdate("id2", False, "baby", "child") - client.ft().synupdate("id3", False, "tree", "wood") - res = client.ft().syndump() - assert res == { - "boy": ["id1"], - "tree": ["id3"], - "wood": ["id3"], - "child": ["id1", "id2"], - "baby": ["id2"], - "offspring": ["id1"], - } + client.ft().synupdate("id1", False, "boy", "child", "offspring") + client.ft().synupdate("id2", False, "baby", "child") + client.ft().synupdate("id3", False, "tree", "wood") + res = client.ft().syndump() + assert res == { + "boy": 
["id1"], + "tree": ["id3"], + "wood": ["id3"], + "child": ["id1", "id2"], + "baby": ["id2"], + "offspring": ["id1"], + } diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py index 1d55ea1215..29a8f25855 100644 --- a/tests/test_sentinel.py +++ b/tests/test_sentinel.py @@ -53,6 +53,7 @@ def sentinel(request, cluster): return Sentinel([('foo', 26379), ('bar', 26379)]) +@pytest.mark.onlynoncluster class SentinelTestCluster: def __init__(self, servisentinel_ce_name='mymaster', ip='127.0.0.1', port=6379): diff --git a/tests/test_timeseries.py b/tests/test_timeseries.py index 941f2f9c5f..b2df3feda5 100644 --- a/tests/test_timeseries.py +++ b/tests/test_timeseries.py @@ -11,570 +11,583 @@ def client(modclient): @pytest.mark.redismod -class TestTimeseries: - @pytest.mark.redismod - def testCreate(self, client): - assert client.ts().create(1) - assert client.ts().create(2, retention_msecs=5) - assert client.ts().create(3, labels={"Redis": "Labs"}) - assert client.ts().create(4, retention_msecs=20, - labels={"Time": "Series"}) - info = client.ts().info(4) - assert 20 == info.retention_msecs - assert "Series" == info.labels["Time"] - - # Test for a chunk size of 128 Bytes - assert client.ts().create("time-serie-1", chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128, info.chunk_size - - @pytest.mark.redismod - @skip_ifmodversion_lt("1.4.0", "timeseries") - def testCreateDuplicatePolicy(self, client): - # Test for duplicate policy - for duplicate_policy in ["block", "last", "first", "min", "max"]: - ts_name = "time-serie-ooo-{0}".format(duplicate_policy) - assert client.ts().create(ts_name, - duplicate_policy=duplicate_policy) - info = client.ts().info(ts_name) - assert duplicate_policy == info.duplicate_policy - - @pytest.mark.redismod - def testAlter(self, client): - assert client.ts().create(1) - assert 0 == client.ts().info(1).retention_msecs - assert client.ts().alter(1, retention_msecs=10) - assert {} == client.ts().info(1).labels - assert 10, client.ts().info(1).retention_msecs - assert client.ts().alter(1, labels={"Time": "Series"}) - assert "Series" == client.ts().info(1).labels["Time"] - assert 10 == client.ts().info(1).retention_msecs - - # pipe = client.ts().pipeline() - # assert pipe.create(2) - - @pytest.mark.redismod - @skip_ifmodversion_lt("1.4.0", "timeseries") - def testAlterDiplicatePolicy(self, client): - assert client.ts().create(1) - info = client.ts().info(1) - assert info.duplicate_policy is None - assert client.ts().alter(1, duplicate_policy="min") - info = client.ts().info(1) - assert "min" == info.duplicate_policy - - @pytest.mark.redismod - def testAdd(self, client): - assert 1 == client.ts().add(1, 1, 1) - assert 2 == client.ts().add(2, 2, 3, retention_msecs=10) - assert 3 == client.ts().add(3, 3, 2, labels={"Redis": "Labs"}) - assert 4 == client.ts().add( - 4, 4, 2, retention_msecs=10, - labels={"Redis": "Labs", "Time": "Series"} - ) - assert round(time.time()) == \ - round(float(client.ts().add(5, "*", 1)) / 1000) - - info = client.ts().info(4) - assert 10 == info.retention_msecs - assert "Labs" == info.labels["Redis"] - - # Test for a chunk size of 128 Bytes on TS.ADD - assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) - info = client.ts().info("time-serie-1") - assert 128 == info.chunk_size - - @pytest.mark.redismod - @skip_ifmodversion_lt("1.4.0", "timeseries") - def testAddDuplicatePolicy(self, client): - - # Test for duplicate policy BLOCK - assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) - with 
-            client.ts().add(
-                "time-serie-add-ooo-block",
-                1,
-                5.0,
-                duplicate_policy="block"
-            )
-
-        # Test for duplicate policy LAST
-        assert 1 == client.ts().add("time-serie-add-ooo-last", 1, 5.0)
-        assert 1 == client.ts().add(
-            "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last"
-        )
-        assert 10.0 == client.ts().get("time-serie-add-ooo-last")[1]
+def testCreate(client):
+    assert client.ts().create(1)
+    assert client.ts().create(2, retention_msecs=5)
+    assert client.ts().create(3, labels={"Redis": "Labs"})
+    assert client.ts().create(4, retention_msecs=20, labels={"Time": "Series"})
+    info = client.ts().info(4)
+    assert 20 == info.retention_msecs
+    assert "Series" == info.labels["Time"]
 
-        # Test for duplicate policy FIRST
-        assert 1 == client.ts().add("time-serie-add-ooo-first", 1, 5.0)
-        assert 1 == client.ts().add(
-            "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first"
-        )
-        assert 5.0 == client.ts().get("time-serie-add-ooo-first")[1]
+    # Test for a chunk size of 128 Bytes
+    assert client.ts().create("time-serie-1", chunk_size=128)
+    info = client.ts().info("time-serie-1")
+    assert 128 == info.chunk_size
 
-        # Test for duplicate policy MAX
-        assert 1 == client.ts().add("time-serie-add-ooo-max", 1, 5.0)
-        assert 1 == client.ts().add(
-            "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max"
-        )
-        assert 10.0 == client.ts().get("time-serie-add-ooo-max")[1]
-
-        # Test for duplicate policy MIN
-        assert 1 == client.ts().add("time-serie-add-ooo-min", 1, 5.0)
-        assert 1 == client.ts().add(
-            "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min"
-        )
-        assert 5.0 == client.ts().get("time-serie-add-ooo-min")[1]
-
-    @pytest.mark.redismod
-    def testMAdd(self, client):
-        client.ts().create("a")
-        assert [1, 2, 3] == \
-            client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)])
-
-    @pytest.mark.redismod
-    def testIncrbyDecrby(self, client):
-        for _ in range(100):
-            assert client.ts().incrby(1, 1)
-            sleep(0.001)
-        assert 100 == client.ts().get(1)[1]
-        for _ in range(100):
-            assert client.ts().decrby(1, 1)
-            sleep(0.001)
-        assert 0 == client.ts().get(1)[1]
-
-        assert client.ts().incrby(2, 1.5, timestamp=5)
-        assert (5, 1.5) == client.ts().get(2)
-        assert client.ts().incrby(2, 2.25, timestamp=7)
-        assert (7, 3.75) == client.ts().get(2)
-        assert client.ts().decrby(2, 1.5, timestamp=15)
-        assert (15, 2.25) == client.ts().get(2)
-
-        # Test for a chunk size of 128 Bytes on TS.INCRBY
-        assert client.ts().incrby("time-serie-1", 10, chunk_size=128)
-        info = client.ts().info("time-serie-1")
-        assert 128 == info.chunk_size
-
-        # Test for a chunk size of 128 Bytes on TS.DECRBY
-        assert client.ts().decrby("time-serie-2", 10, chunk_size=128)
-        info = client.ts().info("time-serie-2")
-        assert 128 == info.chunk_size
-
-    @pytest.mark.redismod
-    def testCreateAndDeleteRule(self, client):
-        # test rule creation
-        time = 100
-        client.ts().create(1)
-        client.ts().create(2)
-        client.ts().createrule(1, 2, "avg", 100)
-        for i in range(50):
-            client.ts().add(1, time + i * 2, 1)
-            client.ts().add(1, time + i * 2 + 1, 2)
-        client.ts().add(1, time * 2, 1.5)
-        assert round(client.ts().get(2)[1], 5) == 1.5
-        info = client.ts().info(1)
-        assert info.rules[0][1] == 100
-
-        # test rule deletion
-        client.ts().deleterule(1, 2)
-        info = client.ts().info(1)
-        assert not info.rules
-
-    @pytest.mark.redismod
-    @skip_ifmodversion_lt("99.99.99", "timeseries")
-    def testDelRange(self, client):
-        try:
-            client.ts().delete("test", 0, 100)
-        except Exception as e:
-            assert e.__str__() != ""
-
-        for i in range(100):
-            client.ts().add(1, i, i % 7)
-        assert 22 == client.ts().delete(1, 0, 21)
-        assert [] == client.ts().range(1, 0, 21)
-        assert [(22, 1.0)] == client.ts().range(1, 22, 22)
-
-    @pytest.mark.redismod
-    def testRange(self, client):
-        for i in range(100):
-            client.ts().add(1, i, i % 7)
-        assert 100 == len(client.ts().range(1, 0, 200))
-        for i in range(100):
-            client.ts().add(1, i + 200, i % 7)
-        assert 200 == len(client.ts().range(1, 0, 500))
-        # last sample isn't returned
-        assert 20 == len(
-            client.ts().range(
-                1,
-                0,
-                500,
-                aggregation_type="avg",
-                bucket_size_msec=10
-            )
-        )
-        assert 10 == len(client.ts().range(1, 0, 500, count=10))
-
-    @pytest.mark.redismod
-    @skip_ifmodversion_lt("99.99.99", "timeseries")
-    def testRangeAdvanced(self, client):
-        for i in range(100):
-            client.ts().add(1, i, i % 7)
-            client.ts().add(1, i + 200, i % 7)
-
-        assert 2 == len(
-            client.ts().range(
-                1,
-                0,
-                500,
-                filter_by_ts=[i for i in range(10, 20)],
-                filter_by_min_value=1,
-                filter_by_max_value=2,
-            )
-        )
-        assert [(0, 10.0), (10, 1.0)] == client.ts().range(
-            1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+"
-        )
-        assert [(-5, 5.0), (5, 6.0)] == client.ts().range(
-            1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5
-        )
 
+@pytest.mark.redismod
+@skip_ifmodversion_lt("1.4.0", "timeseries")
+def testCreateDuplicatePolicy(client):
+    # Test for duplicate policy
+    for duplicate_policy in ["block", "last", "first", "min", "max"]:
+        ts_name = "time-serie-ooo-{0}".format(duplicate_policy)
+        assert client.ts().create(ts_name, duplicate_policy=duplicate_policy)
+        info = client.ts().info(ts_name)
+        assert duplicate_policy == info.duplicate_policy
 
-    @pytest.mark.redismod
-    @skip_ifmodversion_lt("99.99.99", "timeseries")
-    def testRevRange(self, client):
-        for i in range(100):
-            client.ts().add(1, i, i % 7)
-        assert 100 == len(client.ts().range(1, 0, 200))
-        for i in range(100):
-            client.ts().add(1, i + 200, i % 7)
-        assert 200 == len(client.ts().range(1, 0, 500))
-        # first sample isn't returned
-        assert 20 == len(
-            client.ts().revrange(
-                1,
-                0,
-                500,
-                aggregation_type="avg",
-                bucket_size_msec=10
-            )
-        )
-        assert 10 == len(client.ts().revrange(1, 0, 500, count=10))
-        assert 2 == len(
-            client.ts().revrange(
-                1,
-                0,
-                500,
-                filter_by_ts=[i for i in range(10, 20)],
-                filter_by_min_value=1,
-                filter_by_max_value=2,
-            )
-        )
-        assert [(10, 1.0), (0, 10.0)] == client.ts().revrange(
-            1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+"
-        )
-        assert [(1, 10.0), (-9, 1.0)] == client.ts().revrange(
-            1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1
-        )
 
+@pytest.mark.redismod
+def testAlter(client):
+    assert client.ts().create(1)
+    assert 0 == client.ts().info(1).retention_msecs
+    assert client.ts().alter(1, retention_msecs=10)
+    assert {} == client.ts().info(1).labels
+    assert 10 == client.ts().info(1).retention_msecs
+    assert client.ts().alter(1, labels={"Time": "Series"})
+    assert "Series" == client.ts().info(1).labels["Time"]
+    assert 10 == client.ts().info(1).retention_msecs
+
+
+# pipe = client.ts().pipeline()
+# assert pipe.create(2)
+
+
+@pytest.mark.redismod
+@skip_ifmodversion_lt("1.4.0", "timeseries")
+def testAlterDuplicatePolicy(client):
+ assert client.ts().create(1) + info = client.ts().info(1) + assert info.duplicate_policy is None + assert client.ts().alter(1, duplicate_policy="min") + info = client.ts().info(1) + assert "min" == info.duplicate_policy + + +@pytest.mark.redismod +def testAdd(client): + assert 1 == client.ts().add(1, 1, 1) + assert 2 == client.ts().add(2, 2, 3, retention_msecs=10) + assert 3 == client.ts().add(3, 3, 2, labels={"Redis": "Labs"}) + assert 4 == client.ts().add( + 4, 4, 2, retention_msecs=10, labels={"Redis": "Labs", "Time": "Series"} + ) + assert round(time.time()) == \ + round(float(client.ts().add(5, "*", 1)) / 1000) + + info = client.ts().info(4) + assert 10 == info.retention_msecs + assert "Labs" == info.labels["Redis"] + + # Test for a chunk size of 128 Bytes on TS.ADD + assert client.ts().add("time-serie-1", 1, 10.0, chunk_size=128) + info = client.ts().info("time-serie-1") + assert 128 == info.chunk_size + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.4.0", "timeseries") +def testAddDuplicatePolicy(client): + + # Test for duplicate policy BLOCK + assert 1 == client.ts().add("time-serie-add-ooo-block", 1, 5.0) + with pytest.raises(Exception): + client.ts().add( + "time-serie-add-ooo-block", + 1, + 5.0, + duplicate_policy="block" + ) + + # Test for duplicate policy LAST + assert 1 == client.ts().add("time-serie-add-ooo-last", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-last", 1, 10.0, duplicate_policy="last" + ) + assert 10.0 == client.ts().get("time-serie-add-ooo-last")[1] + + # Test for duplicate policy FIRST + assert 1 == client.ts().add("time-serie-add-ooo-first", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-first", 1, 10.0, duplicate_policy="first" + ) + assert 5.0 == client.ts().get("time-serie-add-ooo-first")[1] + + # Test for duplicate policy MAX + assert 1 == client.ts().add("time-serie-add-ooo-max", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-max", 1, 10.0, duplicate_policy="max" + ) + assert 10.0 == client.ts().get("time-serie-add-ooo-max")[1] + + # Test for duplicate policy MIN + assert 1 == client.ts().add("time-serie-add-ooo-min", 1, 5.0) + assert 1 == client.ts().add( + "time-serie-add-ooo-min", 1, 10.0, duplicate_policy="min" + ) + assert 5.0 == client.ts().get("time-serie-add-ooo-min")[1] + + +@pytest.mark.redismod +def testMAdd(client): + client.ts().create("a") + assert [1, 2, 3] == \ + client.ts().madd([("a", 1, 5), ("a", 2, 10), ("a", 3, 15)]) + + +@pytest.mark.redismod +def testIncrbyDecrby(client): + for _ in range(100): + assert client.ts().incrby(1, 1) + sleep(0.001) + assert 100 == client.ts().get(1)[1] + for _ in range(100): + assert client.ts().decrby(1, 1) + sleep(0.001) + assert 0 == client.ts().get(1)[1] + + assert client.ts().incrby(2, 1.5, timestamp=5) + assert (5, 1.5) == client.ts().get(2) + assert client.ts().incrby(2, 2.25, timestamp=7) + assert (7, 3.75) == client.ts().get(2) + assert client.ts().decrby(2, 1.5, timestamp=15) + assert (15, 2.25) == client.ts().get(2) + + # Test for a chunk size of 128 Bytes on TS.INCRBY + assert client.ts().incrby("time-serie-1", 10, chunk_size=128) + info = client.ts().info("time-serie-1") + assert 128 == info.chunk_size + + # Test for a chunk size of 128 Bytes on TS.DECRBY + assert client.ts().decrby("time-serie-2", 10, chunk_size=128) + info = client.ts().info("time-serie-2") + assert 128 == info.chunk_size - res = client.ts().mrange(0, 200, filters=["Test=This"]) - assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) - res = client.ts().mrange(0, 200, 
filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) +@pytest.mark.redismod +def testCreateAndDeleteRule(client): + # test rule creation + time = 100 + client.ts().create(1) + client.ts().create(2) + client.ts().createrule(1, 2, "avg", 100) + for i in range(50): + client.ts().add(1, time + i * 2, 1) + client.ts().add(1, time + i * 2 + 1, 2) + client.ts().add(1, time * 2, 1.5) + assert round(client.ts().get(2)[1], 5) == 1.5 + info = client.ts().info(1) + assert info.rules[0][1] == 100 + + # test rule deletion + client.ts().deleterule(1, 2) + info = client.ts().info(1) + assert not info.rules + + +@pytest.mark.redismod +@skip_ifmodversion_lt("99.99.99", "timeseries") +def testDelRange(client): + try: + client.ts().delete("test", 0, 100) + except Exception as e: + assert e.__str__() != "" - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrange( + for i in range(100): + client.ts().add(1, i, i % 7) + assert 22 == client.ts().delete(1, 0, 21) + assert [] == client.ts().range(1, 0, 21) + assert [(22, 1.0)] == client.ts().range(1, 22, 22) + + +@pytest.mark.redismod +def testRange(client): + for i in range(100): + client.ts().add(1, i, i % 7) + assert 100 == len(client.ts().range(1, 0, 200)) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + assert 200 == len(client.ts().range(1, 0, 500)) + # last sample isn't returned + assert 20 == len( + client.ts().range( + 1, 0, 500, - filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - - # test withlabels - assert {} == res[0]["1"][0] - res = client.ts().mrange(0, 200, filters=["Test=This"], - with_labels=True) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - - @pytest.mark.redismod - @skip_ifmodversion_lt("99.99.99", "timeseries") - def testMultiRangeAdvanced(self, client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) + ) + assert 10 == len(client.ts().range(1, 0, 500, count=10)) - # test with selected labels - res = client.ts().mrange( - 0, - 200, - filters=["Test=This"], - select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - # test with filterby - res = client.ts().mrange( +@pytest.mark.redismod +@skip_ifmodversion_lt("99.99.99", "timeseries") +def testRangeAdvanced(client): + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(1, i + 200, i % 7) + + assert 2 == len( + client.ts().range( + 1, 0, - 200, - filters=["Test=This"], + 500, filter_by_ts=[i for i in range(10, 20)], filter_by_min_value=1, filter_by_max_value=2, ) - assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] - - # test groupby - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="Test", - reduce="sum" - ) - assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][ - 1] - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="Test", - reduce="max" - ) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][ - 1] - res = client.ts().mrange( - 0, - 3, - filters=["Test=This"], - groupby="team", - reduce="min") - assert 2 == len(res) - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] - assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] - - # test align - res = client.ts().mrange( - 
0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align="-", - ) - assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] - res = client.ts().mrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=5, - ) - assert [(-5, 5.0), (5, 6.0)] == res[0]["1"][1] - - @pytest.mark.redismod - @skip_ifmodversion_lt("99.99.99", "timeseries") - def testMultiReverseRange(self, client): - client.ts().create(1, labels={"Test": "This", "team": "ny"}) - client.ts().create( - 2, - labels={"Test": "This", "Taste": "That", "team": "sf"} - ) - for i in range(100): - client.ts().add(1, i, i % 7) - client.ts().add(2, i, i % 11) - - res = client.ts().mrange(0, 200, filters=["Test=This"]) - assert 2 == len(res) - assert 100 == len(res[0]["1"][1]) + ) + assert [(0, 10.0), (10, 1.0)] == client.ts().range( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ) + assert [(-5, 5.0), (5, 6.0)] == client.ts().range( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=5 + ) - res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) - assert 10 == len(res[0]["1"][1]) - for i in range(100): - client.ts().add(1, i + 200, i % 7) - res = client.ts().mrevrange( +@pytest.mark.redismod +@skip_ifmodversion_lt("99.99.99", "timeseries") +def testRevRange(client): + for i in range(100): + client.ts().add(1, i, i % 7) + assert 100 == len(client.ts().range(1, 0, 200)) + for i in range(100): + client.ts().add(1, i + 200, i % 7) + assert 200 == len(client.ts().range(1, 0, 500)) + # first sample isn't returned + assert 20 == len( + client.ts().revrange( + 1, 0, 500, - filters=["Test=This"], aggregation_type="avg", bucket_size_msec=10 ) - assert 2 == len(res) - assert 20 == len(res[0]["1"][1]) - assert {} == res[0]["1"][0] - - # test withlabels - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], - with_labels=True - ) - assert {"Test": "This", "team": "ny"} == res[0]["1"][0] - - # test with selected labels - res = client.ts().mrevrange( - 0, - 200, - filters=["Test=This"], select_labels=["team"] - ) - assert {"team": "ny"} == res[0]["1"][0] - assert {"team": "sf"} == res[1]["2"][0] - - # test filterby - res = client.ts().mrevrange( + ) + assert 10 == len(client.ts().revrange(1, 0, 500, count=10)) + assert 2 == len( + client.ts().revrange( + 1, 0, - 200, - filters=["Test=This"], + 500, filter_by_ts=[i for i in range(10, 20)], filter_by_min_value=1, filter_by_max_value=2, ) - assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + ) + assert [(10, 1.0), (0, 10.0)] == client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align="+" + ) + assert [(1, 10.0), (-9, 1.0)] == client.ts().revrange( + 1, 0, 10, aggregation_type="count", bucket_size_msec=10, align=1 + ) - # test groupby - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" - ) - assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][ - 1] - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="Test", reduce="max" - ) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][ - 1] - res = client.ts().mrevrange( - 0, 3, filters=["Test=This"], groupby="team", reduce="min" - ) - assert 2 == len(res) - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] - assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] - # test align - res = client.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - 
bucket_size_msec=10, - align="-", - ) - assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] - res = client.ts().mrevrange( - 0, - 10, - filters=["team=ny"], - aggregation_type="count", - bucket_size_msec=10, - align=1, - ) - assert [(1, 10.0), (-9, 1.0)] == res[0]["1"][1] - - @pytest.mark.redismod - def testGet(self, client): - name = "test" - client.ts().create(name) - assert client.ts().get(name) is None - client.ts().add(name, 2, 3) - assert 2 == client.ts().get(name)[0] - client.ts().add(name, 3, 4) - assert 4 == client.ts().get(name)[1] - - @pytest.mark.redismod - def testMGet(self, client): - client.ts().create(1, labels={"Test": "This"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That"}) - act_res = client.ts().mget(["Test=This"]) - exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] - assert act_res == exp_res - client.ts().add(1, "*", 15) - client.ts().add(2, "*", 25) - res = client.ts().mget(["Test=This"]) - assert 15 == res[0]["1"][2] - assert 25 == res[1]["2"][2] - res = client.ts().mget(["Taste=That"]) - assert 25 == res[0]["2"][2] - - # test with_labels - assert {} == res[0]["2"][0] - res = client.ts().mget(["Taste=That"], with_labels=True) - assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] - - @pytest.mark.redismod - def testInfo(self, client): - client.ts().create( - 1, - retention_msecs=5, - labels={"currentLabel": "currentData"} - ) - info = client.ts().info(1) - assert 5 == info.retention_msecs - assert info.labels["currentLabel"] == "currentData" - - @pytest.mark.redismod - @skip_ifmodversion_lt("1.4.0", "timeseries") - def testInfoDuplicatePolicy(self, client): - client.ts().create( - 1, - retention_msecs=5, - labels={"currentLabel": "currentData"} - ) - info = client.ts().info(1) - assert info.duplicate_policy is None - - client.ts().create("time-serie-2", duplicate_policy="min") - info = client.ts().info("time-serie-2") - assert "min" == info.duplicate_policy - - @pytest.mark.redismod - def testQueryIndex(self, client): - client.ts().create(1, labels={"Test": "This"}) - client.ts().create(2, labels={"Test": "This", "Taste": "That"}) - assert 2 == len(client.ts().queryindex(["Test=This"])) - assert 1 == len(client.ts().queryindex(["Taste=That"])) - assert [2] == client.ts().queryindex(["Taste=That"]) - - # - # @pytest.mark.redismod - # @pytest.mark.pipeline - # def testPipeline(client): - # pipeline = client.ts().pipeline() - # pipeline.create("with_pipeline") - # for i in range(100): - # pipeline.add("with_pipeline", i, 1.1 * i) - # pipeline.execute() - - # info = client.ts().info("with_pipeline") - # assert info.lastTimeStamp == 99 - # assert info.total_samples == 100 - # assert client.ts().get("with_pipeline")[1] == 99 * 1.1 - - @pytest.mark.redismod - def testUncompressed(self, client): - client.ts().create("compressed") - client.ts().create("uncompressed", uncompressed=True) - compressed_info = client.ts().info("compressed") - uncompressed_info = client.ts().info("uncompressed") - assert compressed_info.memory_usage != uncompressed_info.memory_usage +@pytest.mark.redismod +def testMultiRange(client): + client.ts().create(1, labels={"Test": "This", "team": "ny"}) + client.ts().create( + 2, + labels={"Test": "This", "Taste": "That", "team": "sf"} + ) + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(2, i, i % 11) + + res = client.ts().mrange(0, 200, filters=["Test=This"]) + assert 2 == len(res) + assert 100 == len(res[0]["1"][1]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == 
len(res[0]["1"][1]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrange( + 0, + 500, + filters=["Test=This"], + aggregation_type="avg", + bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + + # test withlabels + assert {} == res[0]["1"][0] + res = client.ts().mrange(0, 200, filters=["Test=This"], with_labels=True) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("99.99.99", "timeseries") +def testMultiRangeAdvanced(client): + client.ts().create(1, labels={"Test": "This", "team": "ny"}) + client.ts().create( + 2, + labels={"Test": "This", "Taste": "That", "team": "sf"} + ) + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(2, i, i % 11) + + # test with selected labels + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] + + # test with filterby + res = client.ts().mrange( + 0, + 200, + filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(15, 1.0), (16, 2.0)] == res[0]["1"][1] + + # test groupby + res = client.ts().mrange( + 0, + 3, + filters=["Test=This"], + groupby="Test", + reduce="sum" + ) + assert [(0, 0.0), (1, 2.0), (2, 4.0), (3, 6.0)] == res[0]["Test=This"][1] + res = client.ts().mrange( + 0, + 3, + filters=["Test=This"], + groupby="Test", + reduce="max" + ) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["Test=This"][1] + res = client.ts().mrange( + 0, + 3, + filters=["Test=This"], + groupby="team", + reduce="min") + assert 2 == len(res) + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[0]["team=ny"][1] + assert [(0, 0.0), (1, 1.0), (2, 2.0), (3, 3.0)] == res[1]["team=sf"][1] + + # test align + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(0, 10.0), (10, 1.0)] == res[0]["1"][1] + res = client.ts().mrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=5, + ) + assert [(-5, 5.0), (5, 6.0)] == res[0]["1"][1] + + +@pytest.mark.redismod +@skip_ifmodversion_lt("99.99.99", "timeseries") +def testMultiReverseRange(client): + client.ts().create(1, labels={"Test": "This", "team": "ny"}) + client.ts().create( + 2, + labels={"Test": "This", "Taste": "That", "team": "sf"} + ) + for i in range(100): + client.ts().add(1, i, i % 7) + client.ts().add(2, i, i % 11) + + res = client.ts().mrange(0, 200, filters=["Test=This"]) + assert 2 == len(res) + assert 100 == len(res[0]["1"][1]) + + res = client.ts().mrange(0, 200, filters=["Test=This"], count=10) + assert 10 == len(res[0]["1"][1]) + + for i in range(100): + client.ts().add(1, i + 200, i % 7) + res = client.ts().mrevrange( + 0, + 500, + filters=["Test=This"], + aggregation_type="avg", + bucket_size_msec=10 + ) + assert 2 == len(res) + assert 20 == len(res[0]["1"][1]) + assert {} == res[0]["1"][0] + + # test withlabels + res = client.ts().mrevrange( + 0, + 200, + filters=["Test=This"], + with_labels=True + ) + assert {"Test": "This", "team": "ny"} == res[0]["1"][0] + + # test with selected labels + res = client.ts().mrevrange( + 0, + 200, + filters=["Test=This"], select_labels=["team"] + ) + assert {"team": "ny"} == res[0]["1"][0] + assert {"team": "sf"} == res[1]["2"][0] + + # test filterby + res = client.ts().mrevrange( + 0, + 200, + 
filters=["Test=This"], + filter_by_ts=[i for i in range(10, 20)], + filter_by_min_value=1, + filter_by_max_value=2, + ) + assert [(16, 2.0), (15, 1.0)] == res[0]["1"][1] + + # test groupby + res = client.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="sum" + ) + assert [(3, 6.0), (2, 4.0), (1, 2.0), (0, 0.0)] == res[0]["Test=This"][1] + res = client.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="Test", reduce="max" + ) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["Test=This"][1] + res = client.ts().mrevrange( + 0, 3, filters=["Test=This"], groupby="team", reduce="min" + ) + assert 2 == len(res) + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[0]["team=ny"][1] + assert [(3, 3.0), (2, 2.0), (1, 1.0), (0, 0.0)] == res[1]["team=sf"][1] + + # test align + res = client.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align="-", + ) + assert [(10, 1.0), (0, 10.0)] == res[0]["1"][1] + res = client.ts().mrevrange( + 0, + 10, + filters=["team=ny"], + aggregation_type="count", + bucket_size_msec=10, + align=1, + ) + assert [(1, 10.0), (-9, 1.0)] == res[0]["1"][1] + + +@pytest.mark.redismod +def testGet(client): + name = "test" + client.ts().create(name) + assert client.ts().get(name) is None + client.ts().add(name, 2, 3) + assert 2 == client.ts().get(name)[0] + client.ts().add(name, 3, 4) + assert 4 == client.ts().get(name)[1] + + +@pytest.mark.redismod +def testMGet(client): + client.ts().create(1, labels={"Test": "This"}) + client.ts().create(2, labels={"Test": "This", "Taste": "That"}) + act_res = client.ts().mget(["Test=This"]) + exp_res = [{"1": [{}, None, None]}, {"2": [{}, None, None]}] + assert act_res == exp_res + client.ts().add(1, "*", 15) + client.ts().add(2, "*", 25) + res = client.ts().mget(["Test=This"]) + assert 15 == res[0]["1"][2] + assert 25 == res[1]["2"][2] + res = client.ts().mget(["Taste=That"]) + assert 25 == res[0]["2"][2] + + # test with_labels + assert {} == res[0]["2"][0] + res = client.ts().mget(["Taste=That"], with_labels=True) + assert {"Taste": "That", "Test": "This"} == res[0]["2"][0] + + +@pytest.mark.redismod +def testInfo(client): + client.ts().create( + 1, + retention_msecs=5, + labels={"currentLabel": "currentData"} + ) + info = client.ts().info(1) + assert 5 == info.retention_msecs + assert info.labels["currentLabel"] == "currentData" + + +@pytest.mark.redismod +@skip_ifmodversion_lt("1.4.0", "timeseries") +def testInfoDuplicatePolicy(client): + client.ts().create( + 1, + retention_msecs=5, + labels={"currentLabel": "currentData"} + ) + info = client.ts().info(1) + assert info.duplicate_policy is None + + client.ts().create("time-serie-2", duplicate_policy="min") + info = client.ts().info("time-serie-2") + assert "min" == info.duplicate_policy + + +@pytest.mark.redismod +def testQueryIndex(client): + client.ts().create(1, labels={"Test": "This"}) + client.ts().create(2, labels={"Test": "This", "Taste": "That"}) + assert 2 == len(client.ts().queryindex(["Test=This"])) + assert 1 == len(client.ts().queryindex(["Taste=That"])) + assert [2] == client.ts().queryindex(["Taste=That"]) + + +# +# @pytest.mark.redismod +# @pytest.mark.pipeline +# def testPipeline(client): +# pipeline = client.ts().pipeline() +# pipeline.create("with_pipeline") +# for i in range(100): +# pipeline.add("with_pipeline", i, 1.1 * i) +# pipeline.execute() + +# info = client.ts().info("with_pipeline") +# assert info.lastTimeStamp == 99 +# assert info.total_samples == 100 +# assert 
client.ts().get("with_pipeline")[1] == 99 * 1.1 + + +@pytest.mark.redismod +def testUncompressed(client): + client.ts().create("compressed") + client.ts().create("uncompressed", uncompressed=True) + compressed_info = client.ts().info("compressed") + uncompressed_info = client.ts().info("uncompressed") + assert compressed_info.memory_usage != uncompressed_info.memory_usage From 6a71af82de96e904a2efb7c2489f5d2acbcfaa76 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Sun, 7 Nov 2021 20:24:44 +0200 Subject: [PATCH 11/22] Added/marked command tests for cluster mode --- redis/cluster.py | 20 +- redis/commands/cluster.py | 13 +- redis/commands/parser.py | 45 ++-- tests/test_cluster.py | 473 +++++++++++++++++++++++++++++++++++++- tests/test_commands.py | 138 ++++++++++- 5 files changed, 650 insertions(+), 39 deletions(-) diff --git a/redis/cluster.py b/redis/cluster.py index bb2dbba467..df9abcbd3a 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -49,6 +49,13 @@ def get_connection(redis_node, *args, **options): ) +def parse_scan_result(command, res, **options): + keys_list = [] + for primary_res in res.values(): + keys_list += primary_res[1] + return 0, keys_list + + def parse_pubsub_numsub(command, res, **options): numsub_d = OrderedDict() for numsub_tups in res.values(): @@ -211,7 +218,6 @@ class RedisCluster(ClusterCommands, object): "CLIENT LIST", "CLIENT SETNAME", "CLIENT GETNAME", - "CONFIG GET", "CONFIG SET", "CONFIG REWRITE", "CONFIG RESETSTAT", @@ -237,7 +243,6 @@ class RedisCluster(ClusterCommands, object): "SLOWLOG LEN", "SLOWLOG RESET", "WAIT", - "TIME", "SAVE", "MEMORY PURGE", "MEMORY MALLOC-STATS", @@ -270,10 +275,14 @@ class RedisCluster(ClusterCommands, object): "CLUSTER SLOTS", "CLUSTER COUNT-FAILURE-REPORTS", "CLUSTER KEYSLOT", - "RANDOMKEY", "COMMAND", + "COMMAND COUNT", "COMMAND GETKEYS", + "CONFIG GET", "DEBUG", + "RANDOMKEY", + "STRALGO", + "TIME", ], RANDOM, ), @@ -341,7 +350,10 @@ class RedisCluster(ClusterCommands, object): else res), list_keys_to_dict([ "CLIENT UNBLOCK", - ], lambda command, res: 1 if sum(res.values()) > 0 else 0) + ], lambda command, res: 1 if sum(res.values()) > 0 else 0), + list_keys_to_dict([ + "SCAN", + ], parse_scan_result) ) def __init__( diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index d3c2e8572b..6358546802 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -321,8 +321,18 @@ def client_unpause(self, target_nodes=None): return self.execute_command('CLIENT UNPAUSE', target_nodes=target_nodes) + def command_count(self): + """ + Returns Integer reply of number of total commands in this Redis server. + Send to a random node. + """ + return self.execute_command('COMMAND COUNT') + def config_get(self, pattern="*", target_nodes=None): - """Return a dictionary of configuration based on the ``pattern``""" + """ + Return a dictionary of configuration based on the ``pattern`` + If no target nodes are specified, send to a random node + """ return self.execute_command('CONFIG GET', pattern, target_nodes=target_nodes) @@ -573,6 +583,7 @@ def time(self, target_nodes=None): """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). 
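+        For example, a reply of (1635927600, 125664) means 1635927600
+        seconds since the Unix epoch plus 125664 microseconds into
+        that second.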
+ If target_nodes are not specified, send to a random node """ return self.execute_command('TIME', target_nodes=target_nodes) diff --git a/redis/commands/parser.py b/redis/commands/parser.py index 7a8004a913..d8b03271db 100644 --- a/redis/commands/parser.py +++ b/redis/commands/parser.py @@ -33,18 +33,21 @@ def get_keys(self, redis_conn, *args): return None cmd_name = args[0].lower() - cmd_name_split = cmd_name.split() - if len(cmd_name_split) > 1: - # we need to take only the main command, e.g. 'memory' for - # 'memory usage' - cmd_name = cmd_name_split[0] if cmd_name not in self.commands: - # We'll try to reinitialize the commands cache, if the engine - # version has changed, the commands may not be current - self.initialize(redis_conn) - if cmd_name not in self.commands: - raise RedisError("{0} command doesn't exist in Redis commands". - format(cmd_name.upper())) + # try to split the command name and to take only the main command, + # e.g. 'memory' for 'memory usage' + cmd_name_split = cmd_name.split() + cmd_name = cmd_name_split[0] + if cmd_name in self.commands: + # save the splitted command to args + args = cmd_name_split + list(args[1:]) + else: + # We'll try to reinitialize the commands cache, if the engine + # version has changed, the commands may not be current + self.initialize(redis_conn) + if cmd_name not in self.commands: + raise RedisError("{0} command doesn't exist in Redis " + "commands".format(cmd_name.upper())) command = self.commands.get(cmd_name) if 'movablekeys' in command['flags']: @@ -57,8 +60,8 @@ def get_keys(self, redis_conn, *args): # The command doesn't have keys in it return None last_key_pos = command['last_key_pos'] - if last_key_pos == -1: - last_key_pos = len(args) - 1 + if last_key_pos < 0: + last_key_pos = len(args) - abs(last_key_pos) keys_pos = list(range(command['first_key_pos'], last_key_pos + 1, command['step_count'])) keys = [args[pos] for pos in keys_pos] @@ -95,13 +98,21 @@ def _get_pubsub_keys(self, *args): return None args = [str_if_bytes(arg) for arg in args] command = args[0].upper() - if command in ['PUBLISH', 'PUBSUB CHANNELS']: + if command == 'PUBSUB': + # the second argument is a part of the command name, e.g. + # ['PUBSUB', 'NUMSUB', 'foo']. + pubsub_type = args[1].upper() + if pubsub_type in ['CHANNELS', 'NUMSUB']: + keys = args[2:] + elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', + 'PUNSUBSCRIBE']: + # format example: + # SUBSCRIBE channel [channel ...] 
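+            # every argument after the command name is a channel or
+            # pattern, so all of them are returned as the command's keys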
+ keys = list(args[1:]) + elif command == 'PUBLISH': # format example: # PUBLISH channel message keys = [args[1]] - elif command in ['SUBSCRIBE', 'PSUBSCRIBE', 'UNSUBSCRIBE', - 'PUNSUBSCRIBE', 'PUBSUB NUMSUB']: - keys = list(args[1:]) else: keys = None return keys diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 5027759ea3..4dc6e10491 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -1,5 +1,6 @@ -import pytest +import binascii import datetime +import pytest import warnings from time import sleep @@ -13,6 +14,7 @@ from redis.exceptions import ( AskError, ClusterDownError, + DataError, MovedError, RedisClusterException, RedisError @@ -21,7 +23,8 @@ from redis.crc import key_slot from .conftest import ( _get_client, - skip_if_server_version_lt + skip_if_server_version_lt, + skip_unless_arch_bits ) default_host = "127.0.0.1" @@ -58,6 +61,7 @@ def slowlog(request, r): def cleanup(): r.config_set('slowlog-log-slower-than', old_slower_than_value) r.config_set('slowlog-max-len', old_max_legnth_value) + request.addfinalizer(cleanup) # Set the new values @@ -601,6 +605,19 @@ def test_dbsize(self, r): def test_config_set(self, r): assert r.config_set('slowlog-log-slower-than', 0) + def test_cluster_config_resetstat(self, r): + r.ping() + all_info = r.info() + prior_commands_processed = -1 + for node_info in all_info.values(): + prior_commands_processed = node_info['total_commands_processed'] + assert prior_commands_processed >= 1 + r.config_resetstat() + all_info = r.info() + for node_info in all_info.values(): + reset_commands_processed = node_info['total_commands_processed'] + assert reset_commands_processed < prior_commands_processed + def test_client_setname(self, r): r.client_setname('redis_py_test') res = r.client_getname() @@ -983,19 +1000,12 @@ def test_memory_doctor(self, r): with pytest.raises(NotImplementedError): r.memory_doctor() - def test_object(self, r): - r['a'] = 'foo' - assert isinstance(r.object('refcount', 'a'), int) - assert isinstance(r.object('idletime', 'a'), int) - assert r.object('encoding', 'a') in (b'raw', b'embstr') - assert r.object('idletime', 'invalid-key') is None - def test_lastsave(self, r): node = r.get_primaries()[0] assert isinstance(r.lastsave(target_nodes=node), datetime.datetime) - def test_echo(self, r): + def test_cluster_echo(self, r): node = r.get_primaries()[0] assert r.echo('foo bar', node) == b'foo bar' @@ -1079,6 +1089,449 @@ def test_client_kill(self, r, r2): assert len(clients) == 1 assert clients[0].get('name') == 'redis-py-c1' + @skip_if_server_version_lt('2.6.0') + def test_cluster_bitop_not_empty_string(self, r): + r['{foo}a'] = '' + r.bitop('not', '{foo}r', '{foo}a') + assert r.get('{foo}r') is None + + @skip_if_server_version_lt('2.6.0') + def test_cluster_bitop_not(self, r): + test_str = b'\xAA\x00\xFF\x55' + correct = ~0xAA00FF55 & 0xFFFFFFFF + r['{foo}a'] = test_str + r.bitop('not', '{foo}r', '{foo}a') + assert int(binascii.hexlify(r['{foo}r']), 16) == correct + + @skip_if_server_version_lt('2.6.0') + def test_cluster_bitop_not_in_place(self, r): + test_str = b'\xAA\x00\xFF\x55' + correct = ~0xAA00FF55 & 0xFFFFFFFF + r['{foo}a'] = test_str + r.bitop('not', '{foo}a', '{foo}a') + assert int(binascii.hexlify(r['{foo}a']), 16) == correct + + @skip_if_server_version_lt('2.6.0') + def test_cluster_bitop_single_string(self, r): + test_str = b'\x01\x02\xFF' + r['{foo}a'] = test_str + r.bitop('and', '{foo}res1', '{foo}a') + r.bitop('or', '{foo}res2', '{foo}a') + r.bitop('xor', '{foo}res3', '{foo}a') + assert 
r['{foo}res1'] == test_str + assert r['{foo}res2'] == test_str + assert r['{foo}res3'] == test_str + + @skip_if_server_version_lt('2.6.0') + def test_cluster_bitop_string_operands(self, r): + r['{foo}a'] = b'\x01\x02\xFF\xFF' + r['{foo}b'] = b'\x01\x02\xFF' + r.bitop('and', '{foo}res1', '{foo}a', '{foo}b') + r.bitop('or', '{foo}res2', '{foo}a', '{foo}b') + r.bitop('xor', '{foo}res3', '{foo}a', '{foo}b') + assert int(binascii.hexlify(r['{foo}res1']), 16) == 0x0102FF00 + assert int(binascii.hexlify(r['{foo}res2']), 16) == 0x0102FFFF + assert int(binascii.hexlify(r['{foo}res3']), 16) == 0x000000FF + + @skip_if_server_version_lt('6.2.0') + def test_cluster_copy(self, r): + assert r.copy("{foo}a", "{foo}b") == 0 + r.set("{foo}a", "bar") + assert r.copy("{foo}a", "{foo}b") == 1 + assert r.get("{foo}a") == b"bar" + assert r.get("{foo}b") == b"bar" + + @skip_if_server_version_lt('6.2.0') + def test_cluster_copy_and_replace(self, r): + r.set("{foo}a", "foo1") + r.set("{foo}b", "foo2") + assert r.copy("{foo}a", "{foo}b") == 0 + assert r.copy("{foo}a", "{foo}b", replace=True) == 1 + + @skip_if_server_version_lt('6.2.0') + def test_cluster_lmove(self, r): + r.rpush('{foo}a', 'one', 'two', 'three', 'four') + assert r.lmove('{foo}a', '{foo}b') + assert r.lmove('{foo}a', '{foo}b', 'right', 'left') + + @skip_if_server_version_lt('6.2.0') + def test_cluster_blmove(self, r): + r.rpush('{foo}a', 'one', 'two', 'three', 'four') + assert r.blmove('{foo}a', '{foo}b', 5) + assert r.blmove('{foo}a', '{foo}b', 1, 'RIGHT', 'LEFT') + + def test_cluster_msetnx(self, r): + d = {'{foo}a': b'1', '{foo}b': b'2', '{foo}c': b'3'} + assert r.msetnx(d) + d2 = {'{foo}a': b'x', '{foo}d': b'4'} + assert not r.msetnx(d2) + for k, v in d.items(): + assert r[k] == v + assert r.get('{foo}d') is None + + def test_cluster_rename(self, r): + r['{foo}a'] = '1' + assert r.rename('{foo}a', '{foo}b') + assert r.get('{foo}a') is None + assert r['{foo}b'] == b'1' + + def test_cluster_renamenx(self, r): + r['{foo}a'] = '1' + r['{foo}b'] = '2' + assert not r.renamenx('{foo}a', '{foo}b') + assert r['{foo}a'] == b'1' + assert r['{foo}b'] == b'2' + + # LIST COMMANDS + def test_cluster_blpop(self, r): + r.rpush('{foo}a', '1', '2') + r.rpush('{foo}b', '3', '4') + assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'3') + assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'4') + assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'1') + assert r.blpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'2') + assert r.blpop(['{foo}b', '{foo}a'], timeout=1) is None + r.rpush('{foo}c', '1') + assert r.blpop('{foo}c', timeout=1) == (b'{foo}c', b'1') + + def test_cluster_brpop(self, r): + r.rpush('{foo}a', '1', '2') + r.rpush('{foo}b', '3', '4') + assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'4') + assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}b', b'3') + assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'2') + assert r.brpop(['{foo}b', '{foo}a'], timeout=1) == (b'{foo}a', b'1') + assert r.brpop(['{foo}b', '{foo}a'], timeout=1) is None + r.rpush('{foo}c', '1') + assert r.brpop('{foo}c', timeout=1) == (b'{foo}c', b'1') + + def test_cluster_brpoplpush(self, r): + r.rpush('{foo}a', '1', '2') + r.rpush('{foo}b', '3', '4') + assert r.brpoplpush('{foo}a', '{foo}b') == b'2' + assert r.brpoplpush('{foo}a', '{foo}b') == b'1' + assert r.brpoplpush('{foo}a', '{foo}b', timeout=1) is None + assert r.lrange('{foo}a', 0, -1) == [] + assert r.lrange('{foo}b', 0, -1) == [b'1', b'2', b'3', 
b'4'] + + def test_cluster_brpoplpush_empty_string(self, r): + r.rpush('{foo}a', '') + assert r.brpoplpush('{foo}a', '{foo}b') == b'' + + def test_cluster_rpoplpush(self, r): + r.rpush('{foo}a', 'a1', 'a2', 'a3') + r.rpush('{foo}b', 'b1', 'b2', 'b3') + assert r.rpoplpush('{foo}a', '{foo}b') == b'a3' + assert r.lrange('{foo}a', 0, -1) == [b'a1', b'a2'] + assert r.lrange('{foo}b', 0, -1) == [b'a3', b'b1', b'b2', b'b3'] + + def test_cluster_sdiff(self, r): + r.sadd('{foo}a', '1', '2', '3') + assert r.sdiff('{foo}a', '{foo}b') == {b'1', b'2', b'3'} + r.sadd('{foo}b', '2', '3') + assert r.sdiff('{foo}a', '{foo}b') == {b'1'} + + def test_cluster_sdiffstore(self, r): + r.sadd('{foo}a', '1', '2', '3') + assert r.sdiffstore('{foo}c', '{foo}a', '{foo}b') == 3 + assert r.smembers('{foo}c') == {b'1', b'2', b'3'} + r.sadd('{foo}b', '2', '3') + assert r.sdiffstore('{foo}c', '{foo}a', '{foo}b') == 1 + assert r.smembers('{foo}c') == {b'1'} + + def test_cluster_sinter(self, r): + r.sadd('{foo}a', '1', '2', '3') + assert r.sinter('{foo}a', '{foo}b') == set() + r.sadd('{foo}b', '2', '3') + assert r.sinter('{foo}a', '{foo}b') == {b'2', b'3'} + + def test_cluster_sinterstore(self, r): + r.sadd('{foo}a', '1', '2', '3') + assert r.sinterstore('{foo}c', '{foo}a', '{foo}b') == 0 + assert r.smembers('{foo}c') == set() + r.sadd('{foo}b', '2', '3') + assert r.sinterstore('{foo}c', '{foo}a', '{foo}b') == 2 + assert r.smembers('{foo}c') == {b'2', b'3'} + + def test_cluster_smove(self, r): + r.sadd('{foo}a', 'a1', 'a2') + r.sadd('{foo}b', 'b1', 'b2') + assert r.smove('{foo}a', '{foo}b', 'a1') + assert r.smembers('{foo}a') == {b'a2'} + assert r.smembers('{foo}b') == {b'b1', b'b2', b'a1'} + + def test_cluster_sunion(self, r): + r.sadd('{foo}a', '1', '2') + r.sadd('{foo}b', '2', '3') + assert r.sunion('{foo}a', '{foo}b') == {b'1', b'2', b'3'} + + def test_cluster_sunionstore(self, r): + r.sadd('{foo}a', '1', '2') + r.sadd('{foo}b', '2', '3') + assert r.sunionstore('{foo}c', '{foo}a', '{foo}b') == 3 + assert r.smembers('{foo}c') == {b'1', b'2', b'3'} + + @skip_if_server_version_lt('6.2.0') + def test_cluster_zdiff(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) + r.zadd('{foo}b', {'a1': 1, 'a2': 2}) + assert r.zdiff(['{foo}a', '{foo}b']) == [b'a3'] + assert r.zdiff(['{foo}a', '{foo}b'], withscores=True) == [b'a3', b'3'] + + @skip_if_server_version_lt('6.2.0') + def test_cluster_zdiffstore(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) + r.zadd('{foo}b', {'a1': 1, 'a2': 2}) + assert r.zdiffstore("{foo}out", ['{foo}a', '{foo}b']) + assert r.zrange("{foo}out", 0, -1) == [b'a3'] + assert r.zrange("{foo}out", 0, -1, withscores=True) == [(b'a3', 3.0)] + + @skip_if_server_version_lt('6.2.0') + def test_cluster_zinter(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinter(['{foo}a', '{foo}b', '{foo}c']) == [b'a3', b'a1'] + # invalid aggregation + with pytest.raises(DataError): + r.zinter(['{foo}a', '{foo}b', '{foo}c'], + aggregate='foo', withscores=True) + # aggregate with SUM + assert r.zinter(['{foo}a', '{foo}b', '{foo}c'], withscores=True) \ + == [(b'a3', 8), (b'a1', 9)] + # aggregate with MAX + assert r.zinter(['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX', + withscores=True) \ + == [(b'a3', 5), (b'a1', 6)] + # aggregate with MIN + assert r.zinter(['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN', + withscores=True) \ + == [(b'a1', 1), (b'a3', 1)] + # with weights + assert r.zinter({'{foo}a': 
1, '{foo}b': 2, '{foo}c': 3}, + withscores=True) \ + == [(b'a3', 20), (b'a1', 23)] + + def test_cluster_zinterstore_sum(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinterstore('{foo}d', ['{foo}a', '{foo}b', '{foo}c']) == 2 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a3', 8), (b'a1', 9)] + + def test_cluster_zinterstore_max(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinterstore( + '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX') == 2 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a3', 5), (b'a1', 6)] + + def test_cluster_zinterstore_min(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) + r.zadd('{foo}b', {'a1': 2, 'a2': 3, 'a3': 5}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinterstore( + '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN') == 2 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a1', 1), (b'a3', 3)] + + def test_cluster_zinterstore_with_weight(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zinterstore( + '{foo}d', {'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}) == 2 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a3', 20), (b'a1', 23)] + + @skip_if_server_version_lt('4.9.0') + def test_cluster_bzpopmax(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2}) + r.zadd('{foo}b', {'b1': 10, 'b2': 20}) + assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}b', b'b2', 20) + assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}b', b'b1', 10) + assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}a', b'a2', 2) + assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}a', b'a1', 1) + assert r.bzpopmax(['{foo}b', '{foo}a'], timeout=1) is None + r.zadd('{foo}c', {'c1': 100}) + assert r.bzpopmax('{foo}c', timeout=1) == (b'{foo}c', b'c1', 100) + + @skip_if_server_version_lt('4.9.0') + def test_cluster_bzpopmin(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2}) + r.zadd('{foo}b', {'b1': 10, 'b2': 20}) + assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}b', b'b1', 10) + assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}b', b'b2', 20) + assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}a', b'a1', 1) + assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) == ( + b'{foo}a', b'a2', 2) + assert r.bzpopmin(['{foo}b', '{foo}a'], timeout=1) is None + r.zadd('{foo}c', {'c1': 100}) + assert r.bzpopmin('{foo}c', timeout=1) == (b'{foo}c', b'c1', 100) + + @skip_if_server_version_lt('6.2.0') + def test_cluster_zrangestore(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) + assert r.zrangestore('{foo}b', '{foo}a', 0, 1) + assert r.zrange('{foo}b', 0, -1) == [b'a1', b'a2'] + assert r.zrangestore('{foo}b', '{foo}a', 1, 2) + assert r.zrange('{foo}b', 0, -1) == [b'a2', b'a3'] + assert r.zrange('{foo}b', 0, -1, withscores=True) == \ + [(b'a2', 2), (b'a3', 3)] + # reversed order + assert r.zrangestore('{foo}b', '{foo}a', 1, 2, desc=True) + assert r.zrange('{foo}b', 0, -1) == [b'a1', b'a2'] + # by score + assert r.zrangestore('{foo}b', '{foo}a', 2, 1, byscore=True, + offset=0, num=1, desc=True) + assert r.zrange('{foo}b', 0, -1) == [b'a2'] + # by lex + assert 
r.zrangestore('{foo}b', '{foo}a', '[a2', '(a3', bylex=True, + offset=0, num=1) + assert r.zrange('{foo}b', 0, -1) == [b'a2'] + + @skip_if_server_version_lt('6.2.0') + def test_cluster_zunion(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + # sum + assert r.zunion(['{foo}a', '{foo}b', '{foo}c']) == \ + [b'a2', b'a4', b'a3', b'a1'] + assert r.zunion(['{foo}a', '{foo}b', '{foo}c'], withscores=True) == \ + [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + # max + assert r.zunion(['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX', + withscores=True) \ + == [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + # min + assert r.zunion(['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN', + withscores=True) \ + == [(b'a1', 1), (b'a2', 1), (b'a3', 1), (b'a4', 4)] + # with weight + assert r.zunion({'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}, + withscores=True) \ + == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + + def test_cluster_zunionstore_sum(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zunionstore('{foo}d', ['{foo}a', '{foo}b', '{foo}c']) == 4 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + + def test_cluster_zunionstore_max(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zunionstore( + '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MAX') == 4 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + + def test_cluster_zunionstore_min(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 2, 'a3': 3}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 4}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zunionstore( + '{foo}d', ['{foo}a', '{foo}b', '{foo}c'], aggregate='MIN') == 4 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a1', 1), (b'a2', 2), (b'a3', 3), (b'a4', 4)] + + def test_cluster_zunionstore_with_weight(self, r): + r.zadd('{foo}a', {'a1': 1, 'a2': 1, 'a3': 1}) + r.zadd('{foo}b', {'a1': 2, 'a2': 2, 'a3': 2}) + r.zadd('{foo}c', {'a1': 6, 'a3': 5, 'a4': 4}) + assert r.zunionstore( + '{foo}d', {'{foo}a': 1, '{foo}b': 2, '{foo}c': 3}) == 4 + assert r.zrange('{foo}d', 0, -1, withscores=True) == \ + [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + + @skip_if_server_version_lt('2.8.9') + def test_cluster_pfcount(self, r): + members = {b'1', b'2', b'3'} + r.pfadd('{foo}a', *members) + assert r.pfcount('{foo}a') == len(members) + members_b = {b'2', b'3', b'4'} + r.pfadd('{foo}b', *members_b) + assert r.pfcount('{foo}b') == len(members_b) + assert r.pfcount('{foo}a', '{foo}b') == len(members_b.union(members)) + + @skip_if_server_version_lt('2.8.9') + def test_cluster_pfmerge(self, r): + mema = {b'1', b'2', b'3'} + memb = {b'2', b'3', b'4'} + memc = {b'5', b'6', b'7'} + r.pfadd('{foo}a', *mema) + r.pfadd('{foo}b', *memb) + r.pfadd('{foo}c', *memc) + r.pfmerge('{foo}d', '{foo}c', '{foo}a') + assert r.pfcount('{foo}d') == 6 + r.pfmerge('{foo}d', '{foo}b') + assert r.pfcount('{foo}d') == 7 + + def test_cluster_sort_store(self, r): + r.rpush('{foo}a', '2', '3', '1') + assert r.sort('{foo}a', store='{foo}sorted_values') == 3 + assert r.lrange('{foo}sorted_values', 0, -1) == [b'1', b'2', b'3'] + + # GEO COMMANDS + 
@skip_if_server_version_lt('6.2.0') + def test_cluster_geosearchstore(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') + \ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('{foo}barcelona', values) + r.geosearchstore('{foo}places_barcelona', '{foo}barcelona', + longitude=2.191, latitude=41.433, radius=1000) + assert r.zrange('{foo}places_barcelona', 0, -1) == [b'place1'] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt('6.2.0') + def test_geosearchstore_dist(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') + \ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('{foo}barcelona', values) + r.geosearchstore('{foo}places_barcelona', '{foo}barcelona', + longitude=2.191, latitude=41.433, + radius=1000, storedist=True) + # instead of save the geo score, the distance is saved. + assert r.zscore('{foo}places_barcelona', 'place1') == 88.05060698409301 + + @skip_if_server_version_lt('3.2.0') + def test_cluster_georadius_store(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') + \ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('{foo}barcelona', values) + r.georadius('{foo}barcelona', 2.191, 41.433, + 1000, store='{foo}places_barcelona') + assert r.zrange('{foo}places_barcelona', 0, -1) == [b'place1'] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt('3.2.0') + def test_cluster_georadius_store_dist(self, r): + values = (2.1909389952632, 41.433791470673, 'place1') + \ + (2.1873744593677, 41.406342043777, 'place2') + + r.geoadd('{foo}barcelona', values) + r.georadius('{foo}barcelona', 2.191, 41.433, 1000, + store_dist='{foo}places_barcelona') + # instead of save the geo score, the distance is saved. + assert r.zscore('{foo}places_barcelona', 'place1') == 88.05060698409301 + @pytest.mark.onlycluster class TestNodesManager: diff --git a/tests/test_commands.py b/tests/test_commands.py index 6d397e610e..04861403f4 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -61,7 +61,6 @@ def test_case_insensitive_command_names(self, r): assert r.response_callbacks['del'] == r.response_callbacks['DEL'] -@pytest.mark.onlynoncluster class TestRedisCommands: def test_command_on_invalid_key_type(self, r): r.lpush('a', '1') @@ -69,18 +68,21 @@ def test_command_on_invalid_key_type(self, r): r['a'] # SERVER INFORMATION + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_cat_no_category(self, r): categories = r.acl_cat() assert isinstance(categories, list) assert 'read' in categories + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_cat_with_category(self, r): commands = r.acl_cat('read') assert isinstance(commands, list) assert 'get' in commands + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_deluser(self, r, request): username = 'redis-py-user' @@ -105,6 +107,7 @@ def teardown(): assert r.acl_getuser(users[3]) is None assert r.acl_getuser(users[4]) is None + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_genpass(self, r): password = r.acl_genpass() @@ -118,6 +121,7 @@ def test_acl_genpass(self, r): r.acl_genpass(555) assert isinstance(password, str) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_getuser_setuser(self, r, request): username = 'redis-py-user' @@ -206,12 +210,14 @@ def teardown(): hashed_passwords=['-' + hashed_password]) assert len(r.acl_getuser(username)['passwords']) == 1 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def 
test_acl_help(self, r): res = r.acl_help() assert isinstance(res, list) assert len(res) != 0 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_list(self, r, request): username = 'redis-py-user' @@ -225,6 +231,7 @@ def teardown(): users = r.acl_list() assert len(users) == 2 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_log(self, r, request): username = 'redis-py-user' @@ -260,6 +267,7 @@ def teardown(): assert 'client-info' in r.acl_log(count=1)[0] assert r.acl_log_reset() + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_setuser_categories_without_prefix_fails(self, r, request): username = 'redis-py-user' @@ -272,6 +280,7 @@ def teardown(): with pytest.raises(exceptions.DataError): r.acl_setuser(username, categories=['list']) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_setuser_commands_without_prefix_fails(self, r, request): username = 'redis-py-user' @@ -284,6 +293,7 @@ def teardown(): with pytest.raises(exceptions.DataError): r.acl_setuser(username, commands=['get']) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_setuser_add_passwords_and_nopass_fails(self, r, request): username = 'redis-py-user' @@ -296,28 +306,33 @@ def teardown(): with pytest.raises(exceptions.DataError): r.acl_setuser(username, passwords='+mypass', nopass=True) + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_users(self, r): users = r.acl_users() assert isinstance(users, list) assert len(users) > 0 + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_acl_whoami(self, r): username = r.acl_whoami() assert isinstance(username, str) + @pytest.mark.onlynoncluster def test_client_list(self, r): clients = r.client_list() assert isinstance(clients[0], dict) assert 'addr' in clients[0] + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_client_info(self, r): info = r.client_info() assert isinstance(info, dict) assert 'addr' in info + @pytest.mark.onlynoncluster @skip_if_server_version_lt('5.0.0') def test_client_list_type(self, r): with pytest.raises(exceptions.RedisError): @@ -326,6 +341,7 @@ def test_client_list_type(self, r): clients = r.client_list(_type=client_type) assert isinstance(clients, list) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_client_list_client_id(self, r, request): clients = r.client_list() @@ -340,16 +356,19 @@ def test_client_list_client_id(self, r, request): clients_listed = r.client_list(client_id=clients[:-1]) assert len(clients_listed) > 1 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('5.0.0') def test_client_id(self, r): assert r.client_id() > 0 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_client_trackinginfo(self, r): res = r.client_trackinginfo() assert len(res) > 2 assert 'prefixes' in res + @pytest.mark.onlynoncluster @skip_if_server_version_lt('5.0.0') def test_client_unblock(self, r): myid = r.client_id() @@ -357,15 +376,18 @@ def test_client_unblock(self, r): assert not r.client_unblock(myid, error=True) assert not r.client_unblock(myid, error=False) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.9') def test_client_getname(self, r): assert r.client_getname() is None + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.9') def test_client_setname(self, r): assert r.client_setname('redis_py_test') assert r.client_getname() == 'redis_py_test' + 
@pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.9') def test_client_kill(self, r, r2): r.client_setname('redis-py-c1') @@ -385,6 +407,7 @@ def test_client_kill(self, r, r2): assert len(clients) == 1 assert clients[0].get('name') == 'redis-py-c1' + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.12') def test_client_kill_filter_invalid_params(self, r): # empty @@ -399,6 +422,7 @@ def test_client_kill_filter_invalid_params(self, r): with pytest.raises(exceptions.DataError): r.client_kill_filter(_type="caster") + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.12') def test_client_kill_filter_by_id(self, r, r2): r.client_setname('redis-py-c1') @@ -419,6 +443,7 @@ def test_client_kill_filter_by_id(self, r, r2): assert len(clients) == 1 assert clients[0].get('name') == 'redis-py-c1' + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.12') def test_client_kill_filter_by_addr(self, r, r2): r.client_setname('redis-py-c1') @@ -439,6 +464,7 @@ def test_client_kill_filter_by_addr(self, r, r2): assert len(clients) == 1 assert clients[0].get('name') == 'redis-py-c1' + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.9') def test_client_list_after_client_setname(self, r): r.client_setname('redis_py_test') @@ -446,6 +472,7 @@ def test_client_list_after_client_setname(self, r): # we don't know which client ours will be assert 'redis_py_test' in [c['name'] for c in clients] + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_client_kill_filter_by_laddr(self, r, r2): r.client_setname('redis-py-c1') @@ -460,6 +487,7 @@ def test_client_kill_filter_by_laddr(self, r, r2): client_2_addr = clients_by_name['redis-py-c2'].get('laddr') assert r.client_kill_filter(laddr=client_2_addr) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.12') def test_client_kill_filter_by_user(self, r, request): killuser = 'user_to_kill' @@ -473,6 +501,7 @@ def test_client_kill_filter_by_user(self, r, request): assert c['user'] != killuser r.acl_deluser(killuser) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.9.50') def test_client_pause(self, r): assert r.client_pause(1) @@ -480,10 +509,12 @@ def test_client_pause(self, r): with pytest.raises(exceptions.RedisError): r.client_pause(timeout='not an integer') + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_client_unpause(self, r): assert r.client_unpause() == b'OK' + @pytest.mark.onlynoncluster @skip_if_server_version_lt('3.2.0') def test_client_reply(self, r, r_timeout): assert r_timeout.client_reply('ON') == b'OK' @@ -497,6 +528,7 @@ def test_client_reply(self, r, r_timeout): # validate it was set assert r.get('foo') == b'bar' + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.0.0') def test_client_getredir(self, r): assert isinstance(r.client_getredir(), int) @@ -507,6 +539,7 @@ def test_config_get(self, r): assert 'maxmemory' in data assert data['maxmemory'].isdigit() + @pytest.mark.onlynoncluster def test_config_resetstat(self, r): r.ping() prior_commands_processed = int(r.info()['total_commands_processed']) @@ -529,9 +562,11 @@ def test_dbsize(self, r): r['b'] = 'bar' assert r.dbsize() == 2 + @pytest.mark.onlynoncluster def test_echo(self, r): assert r.echo('foo bar') == b'foo bar' + @pytest.mark.onlynoncluster def test_info(self, r): r['a'] = 'foo' r['b'] = 'bar' @@ -539,9 +574,11 @@ def test_info(self, r): assert isinstance(info, dict) assert info['db9']['keys'] == 2 + @pytest.mark.onlynoncluster def test_lastsave(self, r): assert 
isinstance(r.lastsave(), datetime.datetime) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('5.0.0') def test_lolwut(self, r): lolwut = r.lolwut().decode('utf-8') @@ -560,9 +597,11 @@ def test_object(self, r): def test_ping(self, r): assert r.ping() + @pytest.mark.onlynoncluster def test_quit(self, r): assert r.quit() + @pytest.mark.onlynoncluster def test_slowlog_get(self, r, slowlog): assert r.slowlog_reset() unicode_string = chr(3456) + 'abcd' + chr(3421) @@ -614,6 +653,7 @@ def parse_response(connection, command_name, **options): # tear down monkeypatch r.parse_response = old_parse_response + @pytest.mark.onlynoncluster def test_slowlog_get_limit(self, r, slowlog): assert r.slowlog_reset() r.get('foo') @@ -622,6 +662,7 @@ def test_slowlog_get_limit(self, r, slowlog): # only one command, based on the number we passed to slowlog_get() assert len(slowlog) == 1 + @pytest.mark.onlynoncluster def test_slowlog_length(self, r, slowlog): r.get('foo') assert isinstance(r.slowlog_len(), int) @@ -664,12 +705,14 @@ def test_bitcount(self, r): assert r.bitcount('a', -2, -1) == 2 assert r.bitcount('a', 1, 1) == 1 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.0') def test_bitop_not_empty_string(self, r): r['a'] = '' r.bitop('not', 'r', 'a') assert r.get('r') is None + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.0') def test_bitop_not(self, r): test_str = b'\xAA\x00\xFF\x55' @@ -678,6 +721,7 @@ def test_bitop_not(self, r): r.bitop('not', 'r', 'a') assert int(binascii.hexlify(r['r']), 16) == correct + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.0') def test_bitop_not_in_place(self, r): test_str = b'\xAA\x00\xFF\x55' @@ -686,6 +730,7 @@ def test_bitop_not_in_place(self, r): r.bitop('not', 'a', 'a') assert int(binascii.hexlify(r['a']), 16) == correct + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.0') def test_bitop_single_string(self, r): test_str = b'\x01\x02\xFF' @@ -697,6 +742,7 @@ def test_bitop_single_string(self, r): assert r['res2'] == test_str assert r['res3'] == test_str + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.6.0') def test_bitop_string_operands(self, r): r['a'] = b'\x01\x02\xFF\xFF' @@ -708,6 +754,7 @@ def test_bitop_string_operands(self, r): assert int(binascii.hexlify(r['res2']), 16) == 0x0102FFFF assert int(binascii.hexlify(r['res3']), 16) == 0x000000FF + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.7') def test_bitpos(self, r): key = 'key:bitpos' @@ -730,6 +777,7 @@ def test_bitpos_wrong_arguments(self, r): with pytest.raises(exceptions.RedisError): r.bitpos(key, 7) == 12 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_copy(self, r): assert r.copy("a", "b") == 0 @@ -738,6 +786,7 @@ def test_copy(self, r): assert r.get("a") == b"foo" assert r.get("b") == b"foo" + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_copy_and_replace(self, r): r.set("a", "foo1") @@ -745,6 +794,7 @@ def test_copy_and_replace(self, r): assert r.copy("a", "b") == 0 assert r.copy("a", "b", replace=True) == 1 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_copy_to_another_database(self, request): r0 = _get_client(redis.Redis, request, db=0) @@ -969,6 +1019,7 @@ def test_keys(self, r): assert set(r.keys(pattern='test_*')) == keys_with_underscores assert set(r.keys(pattern='test*')) == keys + @pytest.mark.onlynoncluster def test_mget(self, r): assert r.mget([]) == [] assert r.mget(['a', 'b']) == [None, None] @@ -977,24 +1028,28 @@ def 
test_mget(self, r): r['c'] = '3' assert r.mget('a', 'other', 'b', 'c') == [b'1', None, b'2', b'3'] + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_lmove(self, r): r.rpush('a', 'one', 'two', 'three', 'four') assert r.lmove('a', 'b') assert r.lmove('a', 'b', 'right', 'left') + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_blmove(self, r): r.rpush('a', 'one', 'two', 'three', 'four') assert r.blmove('a', 'b', 5) assert r.blmove('a', 'b', 1, 'RIGHT', 'LEFT') + @pytest.mark.onlynoncluster def test_mset(self, r): d = {'a': b'1', 'b': b'2', 'c': b'3'} assert r.mset(d) for k, v in d.items(): assert r[k] == v + @pytest.mark.onlynoncluster def test_msetnx(self, r): d = {'a': b'1', 'b': b'2', 'c': b'3'} assert r.msetnx(d) @@ -1079,12 +1134,14 @@ def test_randomkey(self, r): r[key] = 1 assert r.randomkey() in (b'a', b'b', b'c') + @pytest.mark.onlynoncluster def test_rename(self, r): r['a'] = '1' assert r.rename('a', 'b') assert r.get('a') is None assert r['b'] == b'1' + @pytest.mark.onlynoncluster def test_renamenx(self, r): r['a'] = '1' r['b'] = '2' @@ -1190,8 +1247,8 @@ def test_setrange(self, r): @skip_if_server_version_lt('6.0.0') def test_stralgo_lcs(self, r): - key1 = 'key1' - key2 = 'key2' + key1 = '{foo}key1' + key2 = '{foo}key2' value1 = 'ohmytext' value2 = 'mynewtext' res = 'mytext' @@ -1269,6 +1326,7 @@ def test_type(self, r): assert r.type('a') == b'zset' # LIST COMMANDS + @pytest.mark.onlynoncluster def test_blpop(self, r): r.rpush('a', '1', '2') r.rpush('b', '3', '4') @@ -1280,6 +1338,7 @@ def test_blpop(self, r): r.rpush('c', '1') assert r.blpop('c', timeout=1) == (b'c', b'1') + @pytest.mark.onlynoncluster def test_brpop(self, r): r.rpush('a', '1', '2') r.rpush('b', '3', '4') @@ -1291,6 +1350,7 @@ def test_brpop(self, r): r.rpush('c', '1') assert r.brpop('c', timeout=1) == (b'c', b'1') + @pytest.mark.onlynoncluster def test_brpoplpush(self, r): r.rpush('a', '1', '2') r.rpush('b', '3', '4') @@ -1300,6 +1360,7 @@ def test_brpoplpush(self, r): assert r.lrange('a', 0, -1) == [] assert r.lrange('b', 0, -1) == [b'1', b'2', b'3', b'4'] + @pytest.mark.onlynoncluster def test_brpoplpush_empty_string(self, r): r.rpush('a', '') assert r.brpoplpush('a', 'b') == b'' @@ -1403,6 +1464,7 @@ def test_rpop_count(self, r): assert r.rpop('a') is None assert r.rpop('a', 3) is None + @pytest.mark.onlynoncluster def test_rpoplpush(self, r): r.rpush('a', 'a1', 'a2', 'a3') r.rpush('b', 'b1', 'b2', 'b3') @@ -1546,12 +1608,14 @@ def test_scard(self, r): r.sadd('a', '1', '2', '3') assert r.scard('a') == 3 + @pytest.mark.onlynoncluster def test_sdiff(self, r): r.sadd('a', '1', '2', '3') assert r.sdiff('a', 'b') == {b'1', b'2', b'3'} r.sadd('b', '2', '3') assert r.sdiff('a', 'b') == {b'1'} + @pytest.mark.onlynoncluster def test_sdiffstore(self, r): r.sadd('a', '1', '2', '3') assert r.sdiffstore('c', 'a', 'b') == 3 @@ -1560,12 +1624,14 @@ def test_sdiffstore(self, r): assert r.sdiffstore('c', 'a', 'b') == 1 assert r.smembers('c') == {b'1'} + @pytest.mark.onlynoncluster def test_sinter(self, r): r.sadd('a', '1', '2', '3') assert r.sinter('a', 'b') == set() r.sadd('b', '2', '3') assert r.sinter('a', 'b') == {b'2', b'3'} + @pytest.mark.onlynoncluster def test_sinterstore(self, r): r.sadd('a', '1', '2', '3') assert r.sinterstore('c', 'a', 'b') == 0 @@ -1592,6 +1658,7 @@ def test_smismember(self, r): assert r.smismember('a', '1', '4', '2', '3') == result_list assert r.smismember('a', ['1', '4', '2', '3']) == result_list + @pytest.mark.onlynoncluster def test_smove(self, 
r): r.sadd('a', 'a1', 'a2') r.sadd('b', 'b1', 'b2') @@ -1637,11 +1704,13 @@ def test_srem(self, r): assert r.srem('a', '2', '4') == 2 assert r.smembers('a') == {b'1', b'3'} + @pytest.mark.onlynoncluster def test_sunion(self, r): r.sadd('a', '1', '2') r.sadd('b', '2', '3') assert r.sunion('a', 'b') == {b'1', b'2', b'3'} + @pytest.mark.onlynoncluster def test_sunionstore(self, r): r.sadd('a', '1', '2') r.sadd('b', '2', '3') @@ -1653,6 +1722,7 @@ def test_debug_segfault(self, r): with pytest.raises(NotImplementedError): r.debug_segfault() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('3.2.0') def test_script_debug(self, r): with pytest.raises(NotImplementedError): @@ -1734,6 +1804,7 @@ def test_zcount(self, r): assert r.zcount('a', 1, '(' + str(2)) == 1 assert r.zcount('a', 10, 20) == 0 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_zdiff(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) @@ -1741,6 +1812,7 @@ def test_zdiff(self, r): assert r.zdiff(['a', 'b']) == [b'a3'] assert r.zdiff(['a', 'b'], withscores=True) == [b'a3', b'3'] + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_zdiffstore(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) @@ -1762,6 +1834,7 @@ def test_zlexcount(self, r): assert r.zlexcount('a', '-', '+') == 7 assert r.zlexcount('a', '[b', '[f') == 5 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_zinter(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 1}) @@ -1784,6 +1857,7 @@ def test_zinter(self, r): assert r.zinter({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ == [(b'a3', 20), (b'a1', 23)] + @pytest.mark.onlynoncluster def test_zinterstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) @@ -1792,6 +1866,7 @@ def test_zinterstore_sum(self, r): assert r.zrange('d', 0, -1, withscores=True) == \ [(b'a3', 8), (b'a1', 9)] + @pytest.mark.onlynoncluster def test_zinterstore_max(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) @@ -1800,6 +1875,7 @@ def test_zinterstore_max(self, r): assert r.zrange('d', 0, -1, withscores=True) == \ [(b'a3', 5), (b'a1', 6)] + @pytest.mark.onlynoncluster def test_zinterstore_min(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) r.zadd('b', {'a1': 2, 'a2': 3, 'a3': 5}) @@ -1808,6 +1884,7 @@ def test_zinterstore_min(self, r): assert r.zrange('d', 0, -1, withscores=True) == \ [(b'a1', 1), (b'a3', 3)] + @pytest.mark.onlynoncluster def test_zinterstore_with_weight(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) @@ -1846,6 +1923,7 @@ def test_zrandemember(self, r): # with duplications assert len(r.zrandmember('a', -10)) == 10 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('4.9.0') def test_bzpopmax(self, r): r.zadd('a', {'a1': 1, 'a2': 2}) @@ -1858,6 +1936,7 @@ def test_bzpopmax(self, r): r.zadd('c', {'c1': 100}) assert r.bzpopmax('c', timeout=1) == (b'c', b'c1', 100) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('4.9.0') def test_bzpopmin(self, r): r.zadd('a', {'a1': 1, 'a2': 2}) @@ -1926,6 +2005,7 @@ def test_zrange_params(self, r): # rev assert r.zrange('a', 0, 1, desc=True) == [b'a5', b'a4'] + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_zrangestore(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) @@ -2063,6 +2143,7 @@ def test_zscore(self, r): assert r.zscore('a', 'a2') == 2.0 assert r.zscore('a', 'a4') is None + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') 
def test_zunion(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) @@ -2083,6 +2164,7 @@ def test_zunion(self, r): assert r.zunion({'a': 1, 'b': 2, 'c': 3}, withscores=True) \ == [(b'a2', 5), (b'a4', 12), (b'a3', 20), (b'a1', 23)] + @pytest.mark.onlynoncluster def test_zunionstore_sum(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) @@ -2091,6 +2173,7 @@ def test_zunionstore_sum(self, r): assert r.zrange('d', 0, -1, withscores=True) == \ [(b'a2', 3), (b'a4', 4), (b'a3', 8), (b'a1', 9)] + @pytest.mark.onlynoncluster def test_zunionstore_max(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) @@ -2099,6 +2182,7 @@ def test_zunionstore_max(self, r): assert r.zrange('d', 0, -1, withscores=True) == \ [(b'a2', 2), (b'a4', 4), (b'a3', 5), (b'a1', 6)] + @pytest.mark.onlynoncluster def test_zunionstore_min(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 4}) @@ -2107,6 +2191,7 @@ def test_zunionstore_min(self, r): assert r.zrange('d', 0, -1, withscores=True) == \ [(b'a1', 1), (b'a2', 2), (b'a3', 3), (b'a4', 4)] + @pytest.mark.onlynoncluster def test_zunionstore_with_weight(self, r): r.zadd('a', {'a1': 1, 'a2': 1, 'a3': 1}) r.zadd('b', {'a1': 2, 'a2': 2, 'a3': 2}) @@ -2134,6 +2219,7 @@ def test_pfadd(self, r): assert r.pfadd('a', *members) == 0 assert r.pfcount('a') == len(members) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.9') def test_pfcount(self, r): members = {b'1', b'2', b'3'} @@ -2144,6 +2230,7 @@ def test_pfcount(self, r): assert r.pfcount('b') == len(members_b) assert r.pfcount('a', 'b') == len(members_b.union(members)) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.9') def test_pfmerge(self, r): mema = {b'1', b'2', b'3'} @@ -2238,8 +2325,9 @@ def test_hmget(self, r): assert r.hmget('a', 'a', 'b', 'c') == [b'1', b'2', b'3'] def test_hmset(self, r): - warning_message = (r'^Redis\.hmset\(\) is deprecated\. ' - r'Use Redis\.hset\(\) instead\.$') + redis_class = type(r).__name__ + warning_message = (r'^{0}\.hmset\(\) is deprecated\. 
' + r'Use {0}\.hset\(\) instead\.$'.format(redis_class)) h = {b'a': b'1', b'b': b'2', b'c': b'3'} with pytest.warns(DeprecationWarning, match=warning_message): assert r.hmset('a', h) @@ -2274,6 +2362,7 @@ def test_sort_limited(self, r): r.rpush('a', '3', '2', '1', '4') assert r.sort('a', start=1, num=2) == [b'2', b'3'] + @pytest.mark.onlynoncluster def test_sort_by(self, r): r['score:1'] = 8 r['score:2'] = 3 @@ -2281,6 +2370,7 @@ def test_sort_by(self, r): r.rpush('a', '3', '2', '1') assert r.sort('a', by='score:*') == [b'2', b'3', b'1'] + @pytest.mark.onlynoncluster def test_sort_get(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' @@ -2288,6 +2378,7 @@ def test_sort_get(self, r): r.rpush('a', '2', '3', '1') assert r.sort('a', get='user:*') == [b'u1', b'u2', b'u3'] + @pytest.mark.onlynoncluster def test_sort_get_multi(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' @@ -2296,6 +2387,7 @@ def test_sort_get_multi(self, r): assert r.sort('a', get=('user:*', '#')) == \ [b'u1', b'1', b'u2', b'2', b'u3', b'3'] + @pytest.mark.onlynoncluster def test_sort_get_groups_two(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' @@ -2304,6 +2396,7 @@ def test_sort_get_groups_two(self, r): assert r.sort('a', get=('user:*', '#'), groups=True) == \ [(b'u1', b'1'), (b'u2', b'2'), (b'u3', b'3')] + @pytest.mark.onlynoncluster def test_sort_groups_string_get(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' @@ -2312,6 +2405,7 @@ def test_sort_groups_string_get(self, r): with pytest.raises(exceptions.DataError): r.sort('a', get='user:*', groups=True) + @pytest.mark.onlynoncluster def test_sort_groups_just_one_get(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' @@ -2328,6 +2422,7 @@ def test_sort_groups_no_get(self, r): with pytest.raises(exceptions.DataError): r.sort('a', groups=True) + @pytest.mark.onlynoncluster def test_sort_groups_three_gets(self, r): r['user:1'] = 'u1' r['user:2'] = 'u2' @@ -2352,11 +2447,13 @@ def test_sort_alpha(self, r): assert r.sort('a', alpha=True) == \ [b'a', b'b', b'c', b'd', b'e'] + @pytest.mark.onlynoncluster def test_sort_store(self, r): r.rpush('a', '2', '3', '1') assert r.sort('a', store='sorted_values') == 3 assert r.lrange('sorted_values', 0, -1) == [b'1', b'2', b'3'] + @pytest.mark.onlynoncluster def test_sort_all_options(self, r): r['user:1:username'] = 'zeus' r['user:2:username'] = 'titan' @@ -2389,65 +2486,83 @@ def test_sort_issue_924(self, r): r.execute_command('SADD', 'issue#924', 1) r.execute_command('SORT', 'issue#924') + @pytest.mark.onlynoncluster def test_cluster_addslots(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('ADDSLOTS', 1) is True + @pytest.mark.onlynoncluster def test_cluster_count_failure_reports(self, mock_cluster_resp_int): assert isinstance(mock_cluster_resp_int.cluster( 'COUNT-FAILURE-REPORTS', 'node'), int) + @pytest.mark.onlynoncluster def test_cluster_countkeysinslot(self, mock_cluster_resp_int): assert isinstance(mock_cluster_resp_int.cluster( 'COUNTKEYSINSLOT', 2), int) + @pytest.mark.onlynoncluster def test_cluster_delslots(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('DELSLOTS', 1) is True + @pytest.mark.onlynoncluster def test_cluster_failover(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('FAILOVER', 1) is True + @pytest.mark.onlynoncluster def test_cluster_forget(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('FORGET', 1) is True + @pytest.mark.onlynoncluster def test_cluster_info(self, mock_cluster_resp_info): assert isinstance(mock_cluster_resp_info.cluster('info'), dict) 
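Most of the tests gaining `@pytest.mark.onlynoncluster` above exercise multi-key or cross-slot operations (MGET, SDIFF, ZINTERSTORE, SORT with BY/GET, and so on), which Redis Cluster rejects unless every key hashes to the same slot; that is also why `test_stralgo_lcs` switches to `{foo}`-prefixed keys, since a `{...}` hash tag makes only the tagged substring participate in the CRC16 hash. A minimal sketch of the slot arithmetic using the client-side `keyslot` helper added in this series (key names are illustrative, `rc` is a connected RedisCluster):

``` pycon
    >>> rc.keyslot('key1') == rc.keyslot('key2')  # plain names usually land on different slots
    False
    >>> rc.keyslot('{foo}key1') == rc.keyslot('{foo}key2')  # only '{foo}' is hashed
    True
```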
+ @pytest.mark.onlynoncluster def test_cluster_keyslot(self, mock_cluster_resp_int): assert isinstance(mock_cluster_resp_int.cluster( 'keyslot', 'asdf'), int) + @pytest.mark.onlynoncluster def test_cluster_meet(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('meet', 'ip', 'port', 1) is True + @pytest.mark.onlynoncluster def test_cluster_nodes(self, mock_cluster_resp_nodes): assert isinstance(mock_cluster_resp_nodes.cluster('nodes'), dict) + @pytest.mark.onlynoncluster def test_cluster_replicate(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('replicate', 'nodeid') is True + @pytest.mark.onlynoncluster def test_cluster_reset(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('reset', 'hard') is True + @pytest.mark.onlynoncluster def test_cluster_saveconfig(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('saveconfig') is True + @pytest.mark.onlynoncluster def test_cluster_setslot(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.cluster('setslot', 1, 'IMPORTING', 'nodeid') is True + @pytest.mark.onlynoncluster def test_cluster_slaves(self, mock_cluster_resp_slaves): assert isinstance(mock_cluster_resp_slaves.cluster( 'slaves', 'nodeid'), dict) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('3.0.0') def test_readwrite(self, r): assert r.readwrite() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('3.0.0') def test_readonly_invalid_cluster_state(self, r): with pytest.raises(exceptions.RedisError): r.readonly() + @pytest.mark.onlynoncluster @skip_if_server_version_lt('3.0.0') def test_readonly(self, mock_cluster_resp_ok): assert mock_cluster_resp_ok.readonly() is True @@ -2674,6 +2789,7 @@ def test_geosearch_negative(self, r): with pytest.raises(exceptions.DataError): assert r.geosearch('barcelona', member='place3', radius=100, any=1) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('6.2.0') def test_geosearchstore(self, r): values = (2.1909389952632, 41.433791470673, 'place1') + \ @@ -2684,6 +2800,7 @@ def test_geosearchstore(self, r): longitude=2.191, latitude=41.433, radius=1000) assert r.zrange('places_barcelona', 0, -1) == [b'place1'] + @pytest.mark.onlynoncluster @skip_unless_arch_bits(64) @skip_if_server_version_lt('6.2.0') def test_geosearchstore_dist(self, r): @@ -2775,6 +2892,7 @@ def test_georadius_sort(self, r): assert r.georadius('barcelona', 2.191, 41.433, 3000, sort='DESC') == \ [b'place2', b'place1'] + @pytest.mark.onlynoncluster @skip_if_server_version_lt('3.2.0') def test_georadius_store(self, r): values = (2.1909389952632, 41.433791470673, 'place1') + \ @@ -2784,6 +2902,7 @@ def test_georadius_store(self, r): r.georadius('barcelona', 2.191, 41.433, 1000, store='places_barcelona') assert r.zrange('places_barcelona', 0, -1) == [b'place1'] + @pytest.mark.onlynoncluster @skip_unless_arch_bits(64) @skip_if_server_version_lt('3.2.0') def test_georadius_store_dist(self, r): @@ -3614,6 +3733,7 @@ def test_memory_usage(self, r): r.set('foo', 'bar') assert isinstance(r.memory_usage('foo'), int) + @pytest.mark.onlynoncluster @skip_if_server_version_lt('4.0.0') def test_module_list(self, r): assert isinstance(r.module_list(), list) @@ -3626,6 +3746,7 @@ def test_command_count(self, r): assert isinstance(res, int) assert res >= 100 + @pytest.mark.onlynoncluster @skip_if_server_version_lt('4.0.0') def test_module(self, r): with pytest.raises(redis.exceptions.ModuleError) as excinfo: @@ -3678,6 +3799,7 @@ def test_restore(self, r): assert r.restore(key, 0, dumpdata, frequency=5) assert 
r.get(key) == b'blee!' + @pytest.mark.onlynoncluster @skip_if_server_version_lt('5.0.0') def test_replicaof(self, r): @@ -3826,7 +3948,9 @@ def test_get_pubsub_keys(self, r): commands_parser = CommandsParser(r) args1 = ['PUBLISH', 'foo', 'bar'] args2 = ['PUBSUB NUMSUB', 'foo1', 'foo2', 'foo3'] - args3 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3'] + args3 = ['PUBSUB channels', '*'] + args4 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3'] assert commands_parser.get_keys(r, *args1) == ['foo'] assert commands_parser.get_keys(r, *args2) == ['foo1', 'foo2', 'foo3'] - assert commands_parser.get_keys(r, *args3) == ['foo1', 'foo2', 'foo3'] + assert commands_parser.get_keys(r, *args3) == ['*'] + assert commands_parser.get_keys(r, *args4) == ['foo1', 'foo2', 'foo3'] From 37a67bf11ff0808dac5ccd59ec8e27bc279264a6 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 8 Nov 2021 15:49:38 +0200 Subject: [PATCH 12/22] Adjusted the cluster's pubsub tests to keep the pubsub node connections alive so they won't get cleaned up by the GC --- tests/test_cluster.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 4dc6e10491..76c2684da1 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -652,12 +652,14 @@ def test_unlink(self, r): def test_pubsub_channels_merge_results(self, r): nodes = r.get_nodes() channels = [] + pubsub_nodes = [] i = 0 for node in nodes: channel = "foo{0}".format(i) # We will create different pubsub clients where each one is # connected to a different node p = r.pubsub(node) + pubsub_nodes.append(p) p.subscribe(channel) b_channel = channel.encode('utf-8') channels.append(b_channel) @@ -677,12 +679,14 @@ def test_pubsub_channels_merge_results(self, r): def test_pubsub_numsub_merge_results(self, r): nodes = r.get_nodes() + pubsub_nodes = [] channel = "foo" b_channel = channel.encode('utf-8') for node in nodes: # We will create different pubsub clients where each one is # connected to a different node p = r.pubsub(node) + pubsub_nodes.append(p) p.subscribe(channel) # Assert that each node returns that only one client is subscribed sub_chann_num = node.redis_connection.pubsub_numsub(channel) @@ -696,11 +700,13 @@ def test_pubsub_numsub_merge_results(self, r): def test_pubsub_numpat_merge_results(self, r): nodes = r.get_nodes() + pubsub_nodes = [] pattern = "foo*" for node in nodes: # We will create different pubsub clients where each one is # connected to a different node p = r.pubsub(node) + pubsub_nodes.append(p) p.psubscribe(pattern) # Assert that each node returns that only one client is subscribed sub_num_pat = node.redis_connection.pubsub_numpat() From 285e3f0927848175224ad5364a01f584472b63f2 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 15 Nov 2021 10:37:13 +0200 Subject: [PATCH 13/22] Added a default cluster node and changed the default behavior of all non-key-based commands to be executed against the default node, if target_nodes was not passed. All tests were adjusted.
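The change this patch describes is visible directly on the client API; below is a hedged sketch using the accessors it introduces (get_default_node, set_default_node, get_primaries), assuming `rc` is a connected RedisCluster aliased as `Redis`, as in the README examples, and that the cluster has at least two primaries:

``` pycon
    >>> rc.cluster_info()  # no target_nodes: executed on the default node
    >>> rc.get_default_node()  # a primary chosen at initialization
    >>> rc.set_default_node(rc.get_primaries()[1])  # True if the node exists in the cluster
    True
    >>> rc.cluster_info(target_nodes=Redis.ALL_NODES)  # an explicit target still overrides the default
```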
--- CONTRIBUTING.md | 8 + README.md | 100 +++++----- docker/base/create_cluster.sh | 2 +- redis/cluster.py | 360 ++++++++++++++++++++++++---------- redis/commands/cluster.py | 266 +++++++++++++++++++------ redis/commands/core.py | 83 ++++++-- redis/exceptions.py | 12 +- tests/conftest.py | 35 ++-- tests/test_cluster.py | 336 +++++++++++++++++++++++++------ tests/test_command_parser.py | 62 ++++++ tests/test_commands.py | 66 +------ tests/test_pubsub.py | 2 + 12 files changed, 961 insertions(+), 371 deletions(-) create mode 100644 tests/test_command_parser.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index af067e7fdf..fe37ff9abe 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -68,6 +68,14 @@ configuration](https://redis.io/topics/sentinel). ## Testing +Call `invoke tests` to run all tests, or `invoke all-tests` to run the +linters as well. With the 'tests' and 'all-tests' targets, all Redis and +RedisCluster tests will be run. + +It is possible to run only Redis client tests (with cluster mode disabled) by +using `invoke redis-tests`; similarly, RedisCluster tests can be run by using +`invoke cluster-tests`. + Each run of tox starts and stops the various dockers required. Sometimes things get stuck, an `invoke clean` can help. diff --git a/README.md b/README.md index 8b42b65d26..d3fd0aaebf 100644 --- a/README.md +++ b/README.md @@ -939,17 +939,19 @@ C 3 redis-py now supports cluster mode and provides a client for [Redis Cluster](). -The cluster client is based on [redis-py-cluster](https://github.com/Grokzen/redis-py-cluster) -by Grokzen, with a lot of added and -changed functionality. +The cluster client is based on Grokzen's +[redis-py-cluster](https://github.com/Grokzen/redis-py-cluster), adds bug +fixes on top of it, and now supersedes that library. Much of this support is +owed to his contributions. + **Create RedisCluster:** -Connecting redis-py to the Redis Cluster instance(s) is easy. -RedisCluster requires at least one node to discover the whole cluster nodes, -and there is multiple ways of creating a RedisCluster instance: +Connecting redis-py to a Redis Cluster requires, at a minimum, a +single node for cluster discovery. There are multiple ways in which a cluster +instance can be created: -- Use the 'host' and 'port' arguments: +- Using 'host' and 'port' arguments: ``` pycon >>> from redis.cluster import RedisCluster as Redis @@ -957,14 +959,14 @@ and there is multiple ways of creating a RedisCluster instance: >>> print(rc.get_nodes()) [[host=127.0.0.1,port=6379,name=127.0.0.1:6379,server_type=primary,redis_connection=Redis>>], [host=127.0.0.1,port=6378,name=127.0.0.1:6378,server_type=primary,redis_connection=Redis>>], [host=127.0.0.1,port=6377,name=127.0.0.1:6377,server_type=replica,redis_connection=Redis>>]] ``` -- Use Redis URL: +- Using the Redis URL specification: ``` pycon >>> from redis.cluster import RedisCluster as Redis >>> rc = Redis.from_url("redis://localhost:6379/0") ``` -- Use ClusterNode(s): +- Directly, via the ClusterNode class: ``` pycon >>> from redis.cluster import RedisCluster as Redis >>> from redis.cluster import ClusterNode >>> nodes = [ClusterNode('localhost', 6379), ClusterNode('localhost', 6378)] >>> rc = Redis(startup_nodes=nodes) @@ -987,13 +989,17 @@ RedisCluster instance can be directly used to execute Redis commands. When a command is being executed through the cluster instance, the target node(s) will be internally determined. When using a key-based command, the target node will be the node that holds the key's slot.
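This key-to-slot-to-node lookup is exposed directly on the client; the sketch below uses two helpers added in this patch series (keyslot and get_node_from_key) and is illustrative rather than canonical README text:

``` pycon
    >>> slot = rc.keyslot('foo1')  # CRC16('foo1') % 16384, computed client-side
    >>> rc.get_node_from_key('foo1')  # the primary that owns that slot
    >>> rc.get_node_from_key('foo1', replica=True)  # a replica holding that slot, if one exists
```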
-Cluster management commands or other cluster commands have predefined node -group targets (all-primaries, all-nodes, random-node, all-replicas), which are -outlined in the command’s function documentation. -For example, ‘KEYS’ command will be sent to all primaries and return all keys -in the cluster, and ‘CLUSTER NODES’ command will be sent to a random node. -Other management commands will require you to pass the target node/s to execute -the command on. +Cluster management commands and other commands that are not key-based have a +parameter called 'target_nodes' where you can specify which nodes to execute +the command on. In the absence of target_nodes, the command will be executed +on the default cluster node. As part of cluster instance initialization, the +cluster's default node is randomly selected from the cluster's primaries, and +will be updated upon reinitialization. Using r.get_default_node(), you can +get the cluster's default node, or you can change it using the +'set_default_node' method. + +The 'target_nodes' parameter is explained in the following section, +'Specifying Target Nodes'. ``` pycon >>> # target-nodes: the node that holds 'foo1's key slot >>> rc.set('foo1', 'bar1') @@ -1003,20 +1009,18 @@ the command on. >>> # target-nodes: the node that holds 'foo2's key slot >>> rc.set('foo2', 'bar2') >>> # target-nodes: the node that holds 'foo1's key slot >>> print(rc.get('foo1')) b'bar1' - >>> # target-nodes: all-primaries + >>> # target-node: default-node >>> print(rc.keys()) - [b'foo1', b'foo2'] - >>> # target-nodes: all-nodes - >>> rc.flushall() + [b'foo1'] + >>> # target-node: default-node + >>> rc.ping() ``` **Specifying Target Nodes:** -As mentioned above, some RedisCluster commands will require you to provide the -target node/s that you want to execute the command on, and in other cases, the -target node will be determined by the client itself. That being said, ALL -RedisCluster commands can be executed against a specific node or a group of -nodes by passing the command kwarg `target_nodes`. +As mentioned above, all non-key-based RedisCluster commands accept the +keyword argument 'target_nodes', which specifies the node/nodes that the +command should be executed on. The best practice is to specify target nodes using RedisCluster class's node flags: PRIMARIES, REPLICAS, ALL_NODES, RANDOM. When a nodes flag is passed along with a command, it will be internally resolved to the relevant node/s. @@ -1027,13 +1031,14 @@ and attempt to retry executing the command. ``` pycon >>> from redis.cluster import RedisCluster as Redis >>> # run cluster-meet command on all of the cluster's nodes - >>> rc.cluster_meet(Redis.ALL_NODES, '127.0.0.1', 6379) + >>> rc.cluster_meet('127.0.0.1', 6379, target_nodes=Redis.ALL_NODES) >>> # ping all replicas - >>> rc.ping(Redis.REPLICAS) + >>> rc.ping(target_nodes=Redis.REPLICAS) >>> # ping a specific node - >>> rc.ping(Redis.RANDOM) - >>> # ping all nodes in the cluster, default command behavior - >>> rc.ping() + >>> rc.ping(target_nodes=Redis.RANDOM) + >>> # get the keys from all cluster nodes + >>> rc.keys(target_nodes=Redis.ALL_NODES) + [b'foo1', b'foo2'] >>> # execute bgsave in all primaries >>> rc.bgsave(target_nodes=Redis.PRIMARIES) ``` @@ -1047,15 +1052,15 @@ the relevant cluster or connection error will be returned.
``` pycon >>> node = rc.get_node('localhost', 6379) >>> # Get the keys only for that specific node - >>> rc.keys(node) + >>> rc.keys(target_nodes=node) >>> # get Redis info from a subset of primaries >>> subset_primaries = [node for node in rc.get_primaries() if node.port > 6378] - >>> rc.info(subset_primaries) + >>> rc.info(target_nodes=subset_primaries) ``` -In addition, you can use the RedisCluster instance to obtain the Redis instance -of a specific node and execute commands on that node directly. The Redis client, -however, cannot handle cluster failures and retries. +In addition, the RedisCluster instance can query the Redis instance of a +specific node and execute commands on that node directly. The Redis client, +however, does not handle cluster failures and retries. ``` pycon >>> cluster_node = rc.get_node(host='localhost', port=6379) @@ -1107,12 +1112,12 @@ first command execution. The node will be determined by: *Known limitations with pubsub:* -Pattern subscribe and publish do not work properly because if we hash a pattern -like fo* we will get a keyslot for that string but there is a endless -possibilities of channel names based on that pattern that we can’t know in -advance. This feature is not limited but the commands is not recommended to use -right now. -See [redis-py-cluster documentaion](https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html) +Pattern subscribe and publish do not currently work properly due to key slots. +If we hash a pattern like fo* we will receive a keyslot for that string but +there are endless possibilities for channel names based on this pattern - +unknowable in advance. This feature is not disabled but the commands are not +currently recommended for use. +See [redis-py-cluster documentation](https://redis-py-cluster.readthedocs.io/en/stable/pubsub.html) for more. ``` pycon @@ -1126,19 +1131,20 @@ See [redis-py-cluster documentaion](https://redis-py-cluster.readthedocs.io/en/s **Read Only Mode** By default, Redis Cluster always returns MOVE redirection response on accessing -a replica node. You can overcome this limitation and scale read commands with -READONLY mode. +a replica node. You can overcome this limitation and scale read commands by +triggering READONLY mode. To enable READONLY mode pass read_from_replicas=True to RedisCluster constructor. When set to true, read commands will be assigned between the primary and its replications in a Round-Robin manner. -You could also enable READONLY mode in runtime by running readonly() method, -or disable it with readwrite(). +READONLY mode can be set at runtime by calling the readonly() method with +target_nodes='replicas', and read-write access can be restored by calling the +readwrite() method. ``` pycon >>> from cluster import RedisCluster as Redis - # Use 'debug' mode to print the node that the command is executed on + # Use 'debug' log level to print the node that the command is executed on >>> rc_readonly = Redis(startup_nodes=startup_nodes, read_from_replicas=True, debug=True) >>> rc_readonly.set('{foo}1', 'bar1') @@ -1148,7 +1154,7 @@ or disable it with readwrite(). 
# set command would be directed only to the slot's primary node >>> rc_readonly.set('{foo}2', 'bar2') # reset READONLY flag - >>> rc_readonly.readwrite() + >>> rc_readonly.readwrite(target_nodes='replicas') # now the get command would be directed only to the slot's primary node >>> rc_readonly.get('{foo}1') ``` diff --git a/docker/base/create_cluster.sh b/docker/base/create_cluster.sh index 28aa3b1b8d..82a79c80da 100644 --- a/docker/base/create_cluster.sh +++ b/docker/base/create_cluster.sh @@ -9,7 +9,7 @@ for PORT in $(seq 16379 16384); do touch /nodes/$PORT/redis.conf fi cat << EOF >> /nodes/$PORT/redis.conf -port $PORT +port ${PORT} cluster-enabled yes daemonize yes logfile /redis.log diff --git a/redis/cluster.py b/redis/cluster.py index df9abcbd3a..ee40acfdae 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -1,9 +1,9 @@ import copy +import logging import random import socket import time import threading -import warnings import sys from collections import OrderedDict @@ -38,6 +38,8 @@ safe_str ) +log = logging.getLogger(__name__) + def get_node_name(host, port): return '{0}:{1}'.format(host, port) @@ -92,7 +94,7 @@ def fix_server(*args): PRIMARY = "primary" REPLICA = "replica" -SLOT_ID = 'slot-id' +SLOT_ID = "slot-id" REDIS_ALLOWED_KEYS = ( "charset", @@ -109,6 +111,7 @@ def fix_server(*args): "redis_connect_func", "password", "port", + "retry", "retry_on_timeout", "socket_connect_timeout", "socket_keepalive", @@ -200,16 +203,18 @@ class ClusterParser(DefaultParser): class RedisCluster(ClusterCommands, object): RedisClusterRequestTTL = 16 - PRIMARIES = "all-primaries" - REPLICAS = "all-replicas" - ALL_NODES = "all-nodes" + PRIMARIES = "primaries" + REPLICAS = "replicas" + ALL_NODES = "all" RANDOM = "random" + DEFAULT_NODE = "default-node" NODE_FLAGS = { PRIMARIES, REPLICAS, ALL_NODES, - RANDOM + RANDOM, + DEFAULT_NODE } COMMAND_FLAGS = dict_merge( @@ -227,12 +232,7 @@ class RedisCluster(ClusterCommands, object): "PUBSUB NUMSUB", "PING", "INFO", - "SHUTDOWN" - ], - ALL_NODES, - ), - list_keys_to_dict( - [ + "SHUTDOWN", "KEYS", "SCAN", "FLUSHALL", @@ -257,21 +257,14 @@ class RedisCluster(ClusterCommands, object): "CLIENT GETREDIR", "CLIENT INFO", "CLIENT KILL" - ], - PRIMARIES, - ), - list_keys_to_dict( - [ "READONLY", "READWRITE", - ], - REPLICAS, - ), - list_keys_to_dict( - [ "CLUSTER INFO", + "CLUSTER MEET", "CLUSTER NODES", "CLUSTER REPLICAS", + "CLUSTER RESET", + "CLUSTER SET-CONFIG-EPOCH", "CLUSTER SLOTS", "CLUSTER COUNT-FAILURE-REPORTS", "CLUSTER KEYSLOT", @@ -281,10 +274,11 @@ class RedisCluster(ClusterCommands, object): "CONFIG GET", "DEBUG", "RANDOMKEY", - "STRALGO", + "READONLY", + "READWRITE", "TIME", ], - RANDOM, + DEFAULT_NODE, ), list_keys_to_dict( [ @@ -367,10 +361,13 @@ def __init__( reinitialize_steps=10, read_from_replicas=False, url=None, - debug=False, + retry_on_timeout=False, + retry=None, **kwargs ): """ + Initialize a new RedisCluster client. 
+ :startup_nodes: 'list[ClusterNode]' List of nodes from which initial bootstrapping can be done :host: 'str' @@ -396,9 +393,11 @@ def __init__( :cluster_error_retry_attempts: 'int' Retry command execution attempts when encountering ClusterDownError or ConnectionError - :debug: - Add prints to debug the RedisCluster client - + :retry_on_timeout: 'bool' + To specify a retry policy, first set `retry_on_timeout` to `True` + then set `retry` to a valid `Retry` object + :retry: 'Retry' + a `Retry` object :**kwargs: Extra arguments that will be sent into Redis instance when created (See Official redis-py doc for supported kwargs @@ -407,6 +406,7 @@ def __init__( RedisClusterException: - db (Redis do not support database SELECT in cluster mode) """ + log.info("Creating a new instance of RedisCluster client") if startup_nodes is None: startup_nodes = [] @@ -417,6 +417,10 @@ def __init__( "Argument 'db' is not possible to use in cluster mode" ) + if retry_on_timeout: + kwargs.update({'retry_on_timeout': retry_on_timeout, + 'retry': retry}) + # Get the startup node/s from_url = False if url is not None: @@ -445,7 +449,7 @@ def __init__( "2. list of startup nodes, for example:\n" " RedisCluster(startup_nodes=[ClusterNode('localhost', 6379)," " ClusterNode('localhost', 6378)])") - + log.debug("startup_nodes : {0}".format(startup_nodes)) # Update the connection arguments # Whenever a new connection is established, RedisCluster's on_connect # method should be run @@ -463,7 +467,6 @@ def __init__( self.cluster_error_retry_attempts = cluster_error_retry_attempts self.command_flags = self.__class__.COMMAND_FLAGS.copy() self.node_flags = self.__class__.NODE_FLAGS.copy() - self.debug_mode = debug self.read_from_replicas = read_from_replicas self.reinitialize_counter = 0 self.reinitialize_steps = reinitialize_steps @@ -495,7 +498,11 @@ def __del__(self): def disconnect_connection_pools(self): for node in self.get_nodes(): if node.redis_connection: - node.redis_connection.connection_pool.disconnect() + try: + node.redis_connection.connection_pool.disconnect() + except OSError: + # Client was already disconnected. do nothing + pass @classmethod def from_url(cls, url, **kwargs): @@ -583,6 +590,42 @@ def get_random_node(self): def get_nodes(self): return list(self.nodes_manager.nodes_cache.values()) + def get_node_from_key(self, key, replica=False): + """ + Get the node that holds the key's slot + """ + slot = self.keyslot(key) + slot_cache = self.nodes_manager.slots_cache.get(slot) + if slot_cache is None or len(slot_cache) == 0: + raise SlotNotCoveredError( + 'Slot "{0}" is not covered by the cluster.'.format(slot) + ) + node_idx = 0 + if replica and len(self.nodes_manager.slots_cache[slot]) > 1: + node_idx = 1 + + return slot_cache[node_idx] + + def get_default_node(self): + """ + Get the cluster's default node + """ + return self.nodes_manager.default_node + + def set_default_node(self, node): + """ + Set the default node of the cluster. 
:param node: 'ClusterNode' + :return True if the default node was set, else False + """ + if node is None or self.get_node(node_name=node.name) is None: + log.info("The requested node does not exist in the cluster, so " + "the default node was not changed.") + return False + self.nodes_manager.default_node = node + log.info("Changed the default cluster node to {0}".format(node)) + return True + def pubsub(self, node=None, host=None, port=None, **kwargs): """ Allows passing a ClusterNode, or host&port, to get a pubsub instance @@ -624,23 +667,33 @@ def _determine_nodes(self, *args, **kwargs): # nodes flag passed by the user command_flag = nodes_flag else: - # get the predefined nodes group for this command + # get the nodes group for this command if it was predefined command_flag = self.command_flags.get(command) - + if command_flag: + log.debug("Target node/s for {0}: {1}". + format(command, command_flag)) if command_flag == self.__class__.RANDOM: + # return a random node return [self.get_random_node()] elif command_flag == self.__class__.PRIMARIES: + # return all primaries return self.get_primaries() elif command_flag == self.__class__.REPLICAS: + # return all replicas return self.get_replicas() elif command_flag == self.__class__.ALL_NODES: + # return all nodes return self.get_nodes() + elif command_flag == self.__class__.DEFAULT_NODE: + # return the cluster's default node + return [self.nodes_manager.default_node] else: # get the node that holds the key's slot slot = self.determine_slot(*args) - return [self.nodes_manager. - get_node_from_slot(slot, self.read_from_replicas - and command in READ_COMMANDS)] + node = self.nodes_manager.get_node_from_slot( + slot, self.read_from_replicas and command in READ_COMMANDS) + log.debug("Target for {0}: slot {1}".format(args, slot)) + return [node] def _should_reinitialized(self): # In order not to reinitialize the cluster, the user can set @@ -658,16 +711,24 @@ def keyslot(self, key): k = self.encoder.encode(key) return key_slot(k) + def _get_command_keys(self, *args): + """ + Get the keys in the command. If the command has no keys in it, None is + returned. + """ + redis_conn = self.get_default_node().redis_connection + return self.commands_parser.get_keys(redis_conn, *args) + def determine_slot(self, *args): """ - figure out what slot based on command and args + Figure out the slot based on the command and args """ if self.command_flags.get(args[0]) == SLOT_ID: # The command contains the slot ID return args[1] - redis_conn = self.get_random_node().redis_connection - keys = self.commands_parser.get_keys(redis_conn, *args) + # Get the keys in the command + keys = self._get_command_keys(*args) if keys is None or len(keys) == 0: raise RedisClusterException( "No way to dispatch this command to Redis Cluster. " @@ -799,8 +860,7 @@ def _execute_command(self, target_node, *args, **kwargs): command in READ_COMMANDS) moved = False - if self.debug_mode: - print("Executing command {0} on target node: {1} {2}". + log.debug("Executing command {0} on target node: {1} {2}". 
format(command, target_node.server_type, target_node.name)) redis_node = self.get_redis_connection(target_node) @@ -819,10 +879,10 @@ def _execute_command(self, target_node, *args, **kwargs): return response except (RedisClusterException, BusyLoadingError): - warnings.warn("RedisClusterException || BusyLoadingError") + log.exception("RedisClusterException || BusyLoadingError") raise except ConnectionError: - warnings.warn("ConnectionError") + log.exception("ConnectionError") # ConnectionError can also be raised if we couldn't get a # connection from the pool before timing out, so check that # this is an actual connection before attempting to disconnect. @@ -842,7 +902,7 @@ def _execute_command(self, target_node, *args, **kwargs): self.nodes_manager.initialize() raise except TimeoutError: - warnings.warn("TimeoutError") + log.exception("TimeoutError") if connection is not None: connection.disconnect() @@ -857,7 +917,7 @@ def _execute_command(self, target_node, *args, **kwargs): # same client object is shared between multiple threads. To # reduce the frequency you can set this variable in the # RedisCluster constructor. - warnings.warn("MovedError") + log.exception("MovedError") self.reinitialize_counter += 1 if self._should_reinitialized(): self.nodes_manager.initialize() else: self.nodes_manager.update_moved_exception(e) moved = True except TryAgainError: - warnings.warn("TryAgainError") + log.exception("TryAgainError") if ttl < self.RedisClusterRequestTTL / 2: time.sleep(0.05) except AskError as e: - warnings.warn("AskError") + log.exception("AskError") redirect_addr = get_node_name(host=e.host, port=e.port) asking = True except ClusterDownError as e: - warnings.warn("ClusterDownError") + log.exception("ClusterDownError") # ClusterDownError can occur during a failover and to get # self-healed, we will try to reinitialize the cluster layout # and retry executing the command self.nodes_manager.initialize() raise e except ResponseError as e: message = e.__str__() - warnings.warn("ResponseError: {0}".format(message)) + log.exception("ResponseError: {0}".format(message)) raise e except BaseException as e: - warnings.warn("BaseException") + log.exception("BaseException") if connection: connection.disconnect() raise e @@ -978,6 +1038,7 @@ def __init__(self, startup_nodes, from_url=False, self.nodes_cache = {} self.slots_cache = {} self.startup_nodes = {} + self.default_node = None self.populate_startup_nodes(startup_nodes) self.from_url = from_url self._require_full_coverage = require_full_coverage @@ -991,26 +1052,37 @@ def __init__(self, startup_nodes, from_url=False, self.initialize() def get_node(self, host=None, port=None, node_name=None): - if node_name is None and (host is None or port is None): - warnings.warn( - "get_node requires one of the followings: " + """ + Get the requested node from the cluster's nodes. + :return: ClusterNode if the node exists, else None + """ + if host and port: + # the user passed host and port + if host == "localhost": + host = socket.gethostbyname(host) + return self.nodes_cache.get(get_node_name(host=host, port=port)) + elif node_name: + return self.nodes_cache.get(node_name) + else: + log.error( + "get_node requires one of the following: " "1. node name " "2. 
host and port" ) return None - if host is not None and port is not None: - if host == "localhost": - host = socket.gethostbyname(host) - node_name = get_node_name(host=host, port=port) - return self.nodes_cache.get(node_name) def update_moved_exception(self, exception): self._moved_exception = exception def _update_moved_slots(self): + """ + Update the slot's node with the redirected one + """ e = self._moved_exception redirected_node = self.get_node(host=e.host, port=e.port) - if redirected_node: + if redirected_node is not None: + # The node already exists if redirected_node.server_type is not PRIMARY: # Update the node's server type redirected_node.server_type = PRIMARY @@ -1031,6 +1103,9 @@ def _update_moved_slots(self): self.slots_cache[e.slot_id].remove(redirected_node) # Override the old primary with the new one self.slots_cache[e.slot_id][0] = redirected_node + if self.default_node == old_primary: + # Update the default node with the new primary + self.default_node = redirected_node else: # The new slot owner is a new server, or a server from a different # shard. We need to remove all current nodes from the slot's list @@ -1057,7 +1132,7 @@ def get_node_from_slot(self, slot, read_from_replicas=False, slot, self._require_full_coverage) ) - if read_from_replicas: + if read_from_replicas is True: # get the server index in a Round-Robin manner primary_name = self.slots_cache[slot][0].name node_idx = self.read_load_balancer.get_server_index( @@ -1078,6 +1153,11 @@ def get_node_from_slot(self, slot, read_from_replicas=False, return self.slots_cache[slot][node_idx] def get_nodes_by_server_type(self, server_type): + """ + Get all nodes with the specified server type + :param server_type: 'primary' or 'replica' + :return: list of ClusterNode + """ return [ node for node in self.nodes_cache.values() @@ -1140,10 +1220,7 @@ def create_redis_node(self, host, port, **kwargs): # Create a redis node with a costumed connection pool kwargs.update({"host": host}) kwargs.update({"port": port}) - connection_pool = ConnectionPool(**kwargs) - r = Redis( - connection_pool=connection_pool - ) + r = Redis(connection_pool=ConnectionPool(**kwargs)) else: r = Redis( host=host, @@ -1172,17 +1249,21 @@ def initialize(self): # Create a new Redis connection and let Redis decode the # responses so we won't need to handle that copy_kwargs = copy.deepcopy(kwargs) - copy_kwargs.update({"decode_responses": True}) - copy_kwargs.update({"encoding": "utf-8"}) + copy_kwargs.update({"decode_responses": True, + "encoding": "utf-8"}) r = self.create_redis_node( startup_node.host, startup_node.port, **copy_kwargs) self.startup_nodes[startup_node.name].redis_connection = r cluster_slots = r.execute_command("CLUSTER SLOTS") startup_nodes_reachable = True - except (ConnectionError, TimeoutError): + except (ConnectionError, TimeoutError) as e: + msg = e.__str__ + log.exception('An exception occurred while trying to' + ' initialize the cluster using the seed node' + ' {0}:\n{1}'.format(startup_node.name, msg)) continue except ResponseError as e: - warnings.warn( + log.exception( 'ReseponseError sending "cluster slots" to redis server') # Isn't a cluster connection, so it won't parse these @@ -1204,6 +1285,11 @@ def initialize(self): startup_node, message) ) + # CLUSTER SLOTS command results in the following output: + # [[slot_section[from_slot,to_slot,master,replica1,...,replicaN]]] + # where each node contains the following list: [IP, port, node_id] + # Therefore, cluster_slots[0][2][0] will be the IP address of the + # primary 
node of the first slot section. # If there's only one server in the cluster, its ``host`` is '' # Fix it to the host in startup_nodes if (len(cluster_slots) == 1 @@ -1270,48 +1356,54 @@ def initialize(self): self.create_redis_connections(list(tmp_nodes_cache.values())) fully_covered = self.check_slots_coverage(tmp_slots) - if not fully_covered: - if self._require_full_coverage: - # Despite the requirement that the slots be covered, there - # isn't a full coverage + # Check if the slots are not fully covered + if not fully_covered and self._require_full_coverage: + # Despite the requirement that the slots be covered, there + # isn't a full coverage + raise RedisClusterException( + 'All slots are not covered after query all startup_nodes.' + ' {0} of {1} covered...'.format( + len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) + ) + elif not fully_covered and not self._require_full_coverage: + # The user set require_full_coverage to False. + # In case of full coverage requirement in the cluster's Redis + # configurations, we will raise an exception. Otherwise, we may + # continue with partial coverage. + # see Redis Cluster configuration parameters in + # https://redis.io/topics/cluster-tutorial + if not self._skip_full_coverage_check and \ + self.cluster_require_full_coverage(tmp_nodes_cache): raise RedisClusterException( - 'All slots are not covered after query all startup_nodes.' + 'Not all slots are covered but the cluster\'s ' + 'configuration requires full coverage. Set ' + 'cluster-require-full-coverage configuration to no on ' + 'all of the cluster nodes if you wish the cluster to ' + 'be able to serve without being fully covered.' ' {0} of {1} covered...'.format( len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) ) - else: - # The user set require_full_coverage to False. - # In case of full coverage requirement in the cluster's Redis - # configurations, we will raise an exception. Otherwise, we may - # continue with partial coverage. - # see Redis Cluster configuration parameters in - # https://redis.io/topics/cluster-tutorial - if not self._skip_full_coverage_check and \ - self.cluster_require_full_coverage(tmp_nodes_cache): - raise RedisClusterException( - 'Not all slots are covered but the cluster\'s ' - 'configuration requires full coverage. Set ' - 'cluster-require-full-coverage configuration to no on ' - 'all of the cluster nodes if you wish the cluster to ' - 'be able to serve without being fully covered.' - ' {0} of {1} covered...'.format( - len(self.slots_cache), REDIS_CLUSTER_HASH_SLOTS) - ) # Set the tmp variables to the real variables self.nodes_cache = tmp_nodes_cache self.slots_cache = tmp_slots + # Set the default node + self.default_node = self.get_nodes_by_server_type(PRIMARY)[0] # Populate the startup nodes with all discovered nodes self.populate_startup_nodes(self.nodes_cache.values()) def close(self): + self.default_node = None for node in self.nodes_cache.values(): if node.redis_connection: node.redis_connection.close() def reset(self): - if self.read_load_balancer is not None: + try: self.read_load_balancer.reset() + except TypeError: + # The read_load_balancer is None, do nothing + pass class ClusterPubSub(PubSub): @@ -1332,21 +1424,71 @@ def __init__(self, redis_cluster, node=None, host=None, port=None, 1. Hashing the channel name in the request to find its keyslot 2. Selecting a node that handles the keyslot: If read_from_replicas is set to true, a replica can be selected. 
+ + :type redis_cluster: RedisCluster + :type node: ClusterNode + :type host: str + :type port: int """ + log.info("Creating new instance of ClusterPubSub") self.node = None - connection_pool = None - if host is not None and port is not None: - node = redis_cluster.get_node(host=host, port=port) - self.node = node - if node is not None: - if not isinstance(node, ClusterNode): - raise DataError("'node' must be a ClusterNode") - connection_pool = redis_cluster.get_redis_connection(node). \ - connection_pool + self.set_pubsub_node(redis_cluster, node, host, port) + connection_pool = None if self.node is None else \ + redis_cluster.get_redis_connection(self.node).connection_pool self.cluster = redis_cluster super().__init__(**kwargs, connection_pool=connection_pool, encoder=redis_cluster.encoder) + def set_pubsub_node(self, cluster, node=None, host=None, port=None): + """ + The pubsub node will be set according to the passed node, host and port + When none of the node, host, or port are specified - the node is set + to None and will be determined by the keyslot of the channel in the + first command to be executed. + RedisClusterException will be thrown if the passed node does not exist + in the cluster. + If host is passed without port, or vice versa, a DataError will be + thrown. + :type cluster: RedisCluster + :type node: ClusterNode + :type host: str + :type port: int + """ + if node is not None: + # node is passed by the user + self._raise_on_invalid_node(cluster, node, node.host, node.port) + pubsub_node = node + elif host is not None and port is not None: + # host and port passed by the user + node = cluster.get_node(host=host, port=port) + self._raise_on_invalid_node(cluster, node, host, port) + pubsub_node = node + elif any([host, port]) is True: + # only 'host' or 'port' passed + raise DataError('Passing a host requires passing a port, ' + 'and vice versa') + else: + # nothing passed by the user. set node to None + pubsub_node = None + + self.node = pubsub_node + + def get_pubsub_node(self): + """ + Get the node that is being used as the pubsub connection + """ + return self.node + + def _raise_on_invalid_node(self, redis_cluster, node, host, port): + """ + Raise a RedisClusterException if the node is None or doesn't exist in + the cluster. + """ + if node is None or redis_cluster.get_node(node_name=node.name) is None: + raise RedisClusterException( + "Node {0}:{1} doesn't exist in the cluster" + .format(host, port)) + def execute_command(self, *args, **kwargs): """ Execute a publish/subscribe command. 
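The node-resolution rules above translate into a small usage surface; here is a hedged sketch (host, port, and node choices are illustrative, and `rc` is assumed to be a connected RedisCluster):

``` pycon
    >>> p = rc.pubsub()  # node deferred until the first command's channel keyslot is known
    >>> p = rc.pubsub(rc.get_primaries()[0])  # or pin the pubsub connection to a specific node
    >>> p = rc.pubsub(host='localhost', port=6379)  # host and port must be passed together
    >>> p.get_pubsub_node()  # the ClusterNode backing this pubsub instance
```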
@@ -1404,11 +1546,11 @@ class ClusterPipeline(RedisCluster): def __init__(self, nodes_manager, result_callbacks=None, cluster_response_callbacks=None, startup_nodes=None, read_from_replicas=False, cluster_error_retry_attempts=3, - debug=False, **kwargs): + **kwargs): """ """ + log.info("Creating new instance of ClusterPipeline") self.command_stack = [] - self.debug_mode = debug self.nodes_manager = nodes_manager self.refresh_table_asap = False self.result_callbacks = (result_callbacks or @@ -1466,11 +1608,13 @@ def __bool__(self): def execute_command(self, *args, **kwargs): """ + Wrapper function for pipeline_execute_command """ return self.pipeline_execute_command(*args, **kwargs) def pipeline_execute_command(self, *args, **options): """ + Appends the executed command to the pipeline's command stack """ self.command_stack.append( PipelineCommand(args, options, len(self.command_stack))) @@ -1478,6 +1622,7 @@ def pipeline_execute_command(self, *args, **options): def raise_first_error(self, stack): """ + Raise the first exception on the stack """ for c in stack: r = c.result @@ -1487,6 +1632,7 @@ def raise_first_error(self, stack): def annotate_exception(self, exception, number, command): """ + Provides extra context to the exception prior to it being handled """ cmd = ' '.join(map(safe_str, command)) msg = 'Command # %d (%s) of pipeline caused error: %s' % ( @@ -1495,12 +1641,9 @@ def annotate_exception(self, exception, number, command): def execute(self, raise_on_error=True): """ + Execute all the commands in the current pipeline """ stack = self.command_stack - - if not stack: - return [] - try: return self.send_cluster_commands(stack, raise_on_error) finally: @@ -1555,6 +1698,9 @@ def send_cluster_commands(self, stack, If it reaches the number of times, the command will raise ClusterDownException. """ + if not stack: + return [] + for _ in range(0, self.cluster_error_retry_attempts): try: return self._send_cluster_commands( diff --git a/redis/commands/cluster.py b/redis/commands/cluster.py index 6358546802..1f5f0a2b8c 100644 --- a/redis/commands/cluster.py +++ b/redis/commands/cluster.py @@ -4,7 +4,7 @@ RedisError, ) from redis.crc import key_slot -from .core import DataAccessCommands, PubSubCommands +from .core import DataAccessCommands from .helpers import list_or_args @@ -148,6 +148,24 @@ def unlink(self, *keys): class ClusterManagementCommands: + """ + Redis Cluster management commands + + Commands with the 'target_nodes' argument can be executed on specified + nodes. By default, if target_nodes is not specified, the command will be + executed on the default cluster node. + + :param target_nodes: type can be one of the following: + - nodes flag: 'all', 'primaries', 'replicas', 'random' + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + primary = r.get_primaries()[0] + r.bgsave(target_nodes=primary) + r.bgsave(target_nodes='primaries') + """ def bgsave(self, schedule=True, target_nodes=None): """ Tell the Redis server to save its data to disk. Unlike save(), @@ -321,17 +339,15 @@ def client_unpause(self, target_nodes=None): return self.execute_command('CLIENT UNPAUSE', target_nodes=target_nodes) - def command_count(self): + def command_count(self, target_nodes=None): """ Returns Integer reply of number of total commands in this Redis server. - Send to a random node. 
""" - return self.execute_command('COMMAND COUNT') + return self.execute_command('COMMAND COUNT', target_nodes=target_nodes) def config_get(self, pattern="*", target_nodes=None): """ Return a dictionary of configuration based on the ``pattern`` - If no target nodes are specified, send to a random node """ return self.execute_command('CONFIG GET', pattern, @@ -359,8 +375,6 @@ def config_set(self, name, value, target_nodes=None): def dbsize(self, target_nodes=None): """ Sums the number of keys in the target nodes' DB. - If no target nodes are specified, send to the entire cluster and sum - the results. :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' The node/s to execute the command on @@ -385,7 +399,7 @@ def echo(self, value, target_nodes): def flushall(self, asynchronous=False, target_nodes=None): """ - Delete all keys in the database on all hosts. + Delete all keys in the database. In cluster mode this method is the same as flushdb ``asynchronous`` indicates whether the operation is @@ -430,6 +444,10 @@ def info(self, section=None, target_nodes=None): section, target_nodes=target_nodes) + def keys(self, pattern='*', target_nodes=None): + "Returns a list of keys matching ``pattern``" + return self.execute_command('KEYS', pattern, target_nodes=target_nodes) + def lastsave(self, target_nodes=None): """ Return a Python datetime object representing the last time the @@ -522,21 +540,71 @@ def ping(self, target_nodes=None): Ping the cluster's servers. If no target nodes are specified, sent to all nodes and returns True if the ping was successful across all nodes. - - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on """ return self.execute_command('PING', target_nodes=target_nodes) - def save(self): + def randomkey(self, target_nodes=None): + """ + Returns the name of a random key" + """ + return self.execute_command('RANDOMKEY', target_nodes=target_nodes) + + def save(self, target_nodes=None): """ Tell the Redis server to save its data to disk, blocking until the save is complete """ - return self.execute_command('SAVE') + return self.execute_command('SAVE', target_nodes=target_nodes) + + def scan(self, cursor=0, match=None, count=None, _type=None, + target_nodes=None): + """ + Incrementally return lists of key names. Also return a cursor + indicating the scan position. + + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. + Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + pieces = [cursor] + if match is not None: + pieces.extend([b'MATCH', match]) + if count is not None: + pieces.extend([b'COUNT', count]) + if _type is not None: + pieces.extend([b'TYPE', _type]) + return self.execute_command('SCAN', *pieces, target_nodes=target_nodes) + + def scan_iter(self, match=None, count=None, _type=None, target_nodes=None): + """ + Make an iterator using the SCAN command so that the client doesn't + need to remember the cursor position. - def shutdown(self, save=False, nosave=False): + ``match`` allows for filtering the keys by pattern + + ``count`` provides a hint to Redis about the number of keys to + return per batch. + + ``_type`` filters the returned values by a particular Redis type. 
+ Stock Redis instances allow for the following types: + HASH, LIST, SET, STREAM, STRING, ZSET + Additionally, Redis modules can expose other types as well. + """ + cursor = '0' + while cursor != 0: + cursor, data = self.scan(cursor=cursor, match=match, + count=count, _type=_type, + target_nodes=target_nodes) + yield from data + + def shutdown(self, save=False, nosave=False, target_nodes=None): """Shutdown the Redis server. If Redis has persistence configured, data will be flushed before shutdown. If the "save" option is set, a data flush will be attempted even if there is no persistence @@ -551,7 +619,7 @@ def shutdown(self, save=False, nosave=False): if nosave: args.append('NOSAVE') try: - self.execute_command(*args) + self.execute_command(*args, target_nodes=target_nodes) except ConnectionError: # a ConnectionError here is expected return @@ -579,11 +647,60 @@ def slowlog_reset(self, target_nodes=None): return self.execute_command('SLOWLOG RESET', target_nodes=target_nodes) + def stralgo(self, algo, value1, value2, specific_argument='strings', + len=False, idx=False, minmatchlen=None, withmatchlen=False, + target_nodes=None): + """ + Implements complex algorithms that operate on strings. + Right now the only algorithm implemented is the LCS algorithm + (longest common substring). However new algorithms could be + implemented in the future. + + ``algo`` Right now must be LCS + ``value1`` and ``value2`` Can be two strings or two keys + ``specific_argument`` Specifying if the arguments to the algorithm + will be keys or strings. strings is the default. + ``len`` Returns just the len of the match. + ``idx`` Returns the match positions in each string. + ``minmatchlen`` Restrict the list of matches to the ones of a given + minimal length. Can be provided only when ``idx`` set to True. + ``withmatchlen`` Returns the matches with the len of the match. + Can be provided only when ``idx`` set to True. + """ + # check validity + supported_algo = ['LCS'] + if algo not in supported_algo: + raise DataError("The supported algorithms are: %s" + % (', '.join(supported_algo))) + if specific_argument not in ['keys', 'strings']: + raise DataError("specific_argument can be only" + " keys or strings") + if len and idx: + raise DataError("len and idx cannot be provided together.") + + pieces = [algo, specific_argument.upper(), value1, value2] + if len: + pieces.append(b'LEN') + if idx: + pieces.append(b'IDX') + try: + int(minmatchlen) + pieces.extend([b'MINMATCHLEN', minmatchlen]) + except TypeError: + pass + if withmatchlen: + pieces.append(b'WITHMATCHLEN') + if specific_argument == 'strings' and target_nodes is None: + target_nodes = 'default-node' + return self.execute_command('STRALGO', *pieces, len=len, idx=idx, + minmatchlen=minmatchlen, + withmatchlen=withmatchlen, + target_nodes=target_nodes) + def time(self, target_nodes=None): """ Returns the server time as a 2-item tuple of ints: (seconds since epoch, microseconds into this second). - If target_nodes are not specified, send to a random node """ return self.execute_command('TIME', target_nodes=target_nodes) @@ -594,16 +711,66 @@ def wait(self, num_replicas, timeout, target_nodes=None): we finally have at least ``num_replicas``, or when the ``timeout`` was reached. 
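A sketch of the STRALGO wrapper defined above (sample values taken from the Redis docs; the exact output shape may vary with client options, and the hash-tagged key names are assumptions that keep both keys in one slot):

``` pycon
  >>> rc.stralgo('LCS', 'ohmytext', 'mynewtext')
  b'mytext'
  >>> rc.stralgo('LCS', '{user}a', '{user}b', specific_argument='keys', len=True)
  6
```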
- In cluster mode the WAIT command will be sent to all primaries - and the result will be summed up + If more than one target node are passed the result will be summed up """ return self.execute_command('WAIT', num_replicas, timeout, target_nodes=target_nodes) +class ClusterPubSubCommands: + """ + Redis PubSub commands for RedisCluster use. + see https://redis.io/topics/pubsub + """ + def publish(self, channel, message, target_nodes=None): + """ + Publish ``message`` on ``channel``. + Returns the number of subscribers the message was delivered to. + """ + return self.execute_command('PUBLISH', channel, message, + target_nodes=target_nodes) + + def pubsub_channels(self, pattern='*', target_nodes=None): + """ + Return a list of channels that have at least one subscriber + """ + return self.execute_command('PUBSUB CHANNELS', pattern, + target_nodes=target_nodes) + + def pubsub_numpat(self, target_nodes=None): + """ + Returns the number of subscriptions to patterns + """ + return self.execute_command('PUBSUB NUMPAT', target_nodes=target_nodes) + + def pubsub_numsub(self, *args, target_nodes=None): + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + """ + return self.execute_command('PUBSUB NUMSUB', *args, + target_nodes=target_nodes) + + class ClusterCommands(ClusterManagementCommands, ClusterMultiKeyCommands, - DataAccessCommands, PubSubCommands): + ClusterPubSubCommands, DataAccessCommands): + """ + Redis Cluster commands + + Commands with the 'target_nodes' argument can be executed on specified + nodes. By default, if target_nodes is not specified, the command will be + executed on the default cluster node. + + :param :target_nodes: type can be one of the followings: + - nodes flag: 'all', 'primaries', 'replicas', 'random' + - 'ClusterNode' + - 'list(ClusterNodes)' + - 'dict(any:clusterNodes)' + + for example: + r.cluster_info(target_nodes='all') + """ def cluster_addslots(self, target_node, *slots): """ Assign new hash slots to receiving node. Sends to specified node. @@ -660,16 +827,13 @@ def cluster_failover(self, target_node, option=None): return self.execute_command('CLUSTER FAILOVER', target_nodes=target_node) - def cluster_info(self, target_node=None): + def cluster_info(self, target_nodes=None): """ Provides info about Redis Cluster node state. The command will be sent to a random node in the cluster if no target node is specified. - - :target_node: 'ClusterNode' - The node to execute the command on """ - return self.execute_command('CLUSTER INFO', target_nodes=target_node) + return self.execute_command('CLUSTER INFO', target_nodes=target_nodes) def cluster_keyslot(self, key): """ @@ -678,13 +842,10 @@ def cluster_keyslot(self, key): """ return self.execute_command('CLUSTER KEYSLOT', key) - def cluster_meet(self, target_nodes, host, port): + def cluster_meet(self, host, port, target_nodes=None): """ Force a node cluster to handshake with another node. Sends to specified node. 
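A sketch of the cluster-aware PubSub commands above (the channel name is illustrative):

``` pycon
  >>> rc.publish('chat', 'hello')                  # default node unless target_nodes is set
  >>> rc.pubsub_channels(target_nodes='all')       # merged from every node
  >>> rc.pubsub_numsub('chat', target_nodes='all')
```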
- - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on """ return self.execute_command('CLUSTER MEET', host, port, target_nodes=target_nodes) @@ -700,33 +861,24 @@ def cluster_nodes(self): def cluster_replicate(self, target_nodes, node_id): """ Reconfigure a node as a slave of the specified master node - - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on """ return self.execute_command('CLUSTER REPLICATE', node_id, target_nodes=target_nodes) - def cluster_reset(self, target_nodes, soft=True): + def cluster_reset(self, soft=True, target_nodes=None): """ Reset a Redis Cluster node If 'soft' is True then it will send 'SOFT' argument If 'soft' is False then it will send 'HARD' argument - - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on """ return self.execute_command('CLUSTER RESET', b'SOFT' if soft else b'HARD', target_nodes=target_nodes) - def cluster_save_config(self, target_nodes): + def cluster_save_config(self, target_nodes=None): """ Forces the node to save cluster state on disk - - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on """ return self.execute_command('CLUSTER SAVECONFIG', target_nodes=target_nodes) @@ -737,12 +889,9 @@ def cluster_get_keys_in_slot(self, slot, num_keys): """ return self.execute_command('CLUSTER GETKEYSINSLOT', slot, num_keys) - def cluster_set_config_epoch(self, target_nodes, epoch): + def cluster_set_config_epoch(self, epoch, target_nodes=None): """ Set the configuration epoch in a new node - - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on """ return self.execute_command('CLUSTER SET-CONFIG-EPOCH', epoch, target_nodes=target_nodes) @@ -770,42 +919,37 @@ def cluster_setslot_stable(self, slot_id): """ return self.execute_command('CLUSTER SETSLOT', slot_id, 'STABLE') - def cluster_replicas(self, node_id): + def cluster_replicas(self, node_id, target_nodes=None): """ Provides a list of replica nodes replicating from the specified primary target node. - Sends to random node in the cluster. """ - return self.execute_command('CLUSTER REPLICAS', node_id) + return self.execute_command('CLUSTER REPLICAS', node_id, + target_nodes=target_nodes) - def cluster_slots(self): + def cluster_slots(self, target_nodes=None): """ Get array of Cluster slot to node mappings - - Sends to random node in the cluster """ - return self.execute_command('CLUSTER SLOTS') + return self.execute_command('CLUSTER SLOTS', target_nodes=target_nodes) def readonly(self, target_nodes=None): """ Enables read queries. - The command will be sent to all replica nodes if target_nodes is not - specified. - - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on + The command will be sent to the default cluster node if target_nodes is + not specified. """ - self.read_from_replicas = True + if target_nodes == 'replicas' or target_nodes == 'all': + # read_from_replicas will only be enabled if the READONLY command + # is sent to all replicas + self.read_from_replicas = True return self.execute_command('READONLY', target_nodes=target_nodes) def readwrite(self, target_nodes=None): """ Disables read queries. - The command will be sent to all replica nodes if target_nodes is not - specified. 
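For instance, a hedged sketch of toggling replica reads with the two commands above:

``` pycon
  >>> rc.readonly(target_nodes='replicas')   # also enables read_from_replicas
  >>> rc.get('foo')                          # may now be served by a replica
  >>> rc.readwrite(target_nodes='replicas')  # back to primary-only reads
```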
- - :target_nodes: 'ClusterNode' or 'list(ClusterNodes)' - The node/s to execute the command on + The command will be sent to the default cluster node if target_nodes is + not specified. """ # Reset read from replicas flag self.read_from_replicas = False diff --git a/redis/commands/core.py b/redis/commands/core.py index c556908d48..1593a9510d 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -12,8 +12,11 @@ ) -class AclCommands: - # ACL methods +class ACLCommands: + """ + Redis Access Control List (ACL) commands. + see: https://redis.io/topics/acl + """ def acl_cat(self, category=None): """ Returns a list of categories or commands within a category. @@ -262,6 +265,9 @@ def acl_whoami(self): class ManagementCommands: + """ + Redis management commands + """ def bgrewriteaof(self): "Tell the Redis server to rewrite the AOF file from data in memory." return self.execute_command('BGREWRITEAOF') @@ -716,7 +722,9 @@ def wait(self, num_replicas, timeout): class BasicKeyCommands: - # BASIC KEY COMMANDS + """ + Redis basic key-based commands + """ def append(self, key, value): """ Appends the string ``value`` to the value at ``key``. If ``key`` @@ -1344,7 +1352,10 @@ def unlink(self, *names): class ListCommands: - # LIST COMMANDS + """ + Redis commands for List data type. + see: https://redis.io/topics/data-types#lists + """ def blpop(self, keys, timeout=0): """ LPOP a value off of the first non-empty list @@ -1598,7 +1609,10 @@ def sort(self, name, start=None, num=None, by=None, get=None, class ScanCommands: - # SCAN COMMANDS + """ + Redis SCAN commands. + see: https://redis.io/commands/scan + """ def scan(self, cursor=0, match=None, count=None, _type=None): """ Incrementally return lists of key names. Also return a cursor @@ -1747,7 +1761,10 @@ def zscan_iter(self, name, match=None, count=None, class SetCommands: - # SET COMMANDS + """ + Redis commands for Set data type. + see: https://redis.io/topics/data-types#sets + """ def sadd(self, name, *values): """Add ``value(s)`` to set ``name``""" return self.execute_command('SADD', name, *values) @@ -1838,8 +1855,11 @@ def sunionstore(self, dest, keys, *args): return self.execute_command('SUNIONSTORE', dest, *args) -class StreamsCommands: - # STREAMS COMMANDS +class StreamCommands: + """ + Redis commands for Stream data type. + see: https://redis.io/topics/streams-intro + """ def xack(self, name, groupname, *ids): """ Acknowledges the successful processing of one or more messages. @@ -2279,7 +2299,10 @@ def xtrim(self, name, maxlen=None, approximate=True, minid=None, class SortedSetCommands: - # SORTED SET COMMANDS + """ + Redis commands for Sorted Sets data type. + see: https://redis.io/topics/data-types-intro#redis-sorted-sets + """ def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False, gt=None, lt=None): """ @@ -2758,8 +2781,11 @@ def _zaggregate(self, command, dest, keys, aggregate=None, return self.execute_command(*pieces, **options) -class HyperLogLogCommands: - # HYPERLOGLOG COMMANDS +class HyperlogCommands: + """ + Redis commands of HyperLogLogs data type. + see: https://redis.io/topics/data-types-intro#hyperloglogs + """ def pfadd(self, name, *values): "Adds the specified elements to the specified HyperLogLog." return self.execute_command('PFADD', name, *values) @@ -2777,7 +2803,10 @@ def pfmerge(self, dest, *sources): class HashCommands: - # HASH COMMANDS + """ + Redis commands for Hash data type. 
+ see: https://redis.io/topics/data-types-intro#redis-hashes + """ def hdel(self, name, *keys): "Delete ``keys`` from hash ``name``" return self.execute_command('HDEL', name, *keys) @@ -2873,7 +2902,10 @@ def hstrlen(self, name, key): class PubSubCommands: - # PUBSUB COMMANDS + """ + Redis PubSub commands. + see https://redis.io/topics/pubsub + """ def publish(self, channel, message): """ Publish ``message`` on ``channel``. @@ -2902,7 +2934,10 @@ def pubsub_numsub(self, *args): class ScriptCommands: - # SCRIPT COMMANDS + """ + Redis Lua script commands. see: + https://redis.com/ebook/part-3-next-steps/chapter-11-scripting-redis-with-lua/ + """ def eval(self, script, numkeys, *keys_and_args): """ Execute the Lua ``script``, specifying the ``numkeys`` the script @@ -2976,7 +3011,10 @@ def register_script(self, script): class GeoCommands: - # GEO COMMANDS + """ + Redis Geospatial commands. + see: https://redis.com/redis-best-practices/indexing-patterns/geospatial/ + """ def geoadd(self, name, values, nx=False, xx=False, ch=False): """ Add the specified geospatial items to the specified key identified @@ -3272,7 +3310,10 @@ def _geosearchgeneric(self, command, *args, **kwargs): class ModuleCommands: - # MODULE COMMANDS + """ + Redis Module commands. + see: https://redis.io/topics/modules-intro + """ def module_load(self, path, *args): """ Loads the module from ``path``. @@ -3297,7 +3338,9 @@ def module_list(self): class Script: - "An executable Lua script object returned by ``register_script``" + """ + An executable Lua script object returned by ``register_script`` + """ def __init__(self, registered_client, script): self.registered_client = registered_client @@ -3429,9 +3472,9 @@ def execute(self): class DataAccessCommands(BasicKeyCommands, ListCommands, - ScanCommands, SetCommands, StreamsCommands, + ScanCommands, SetCommands, StreamCommands, SortedSetCommands, - HyperLogLogCommands, HashCommands, GeoCommands, + HyperlogCommands, HashCommands, GeoCommands, ): """ A class containing all of the implemented data access redis commands. @@ -3439,7 +3482,7 @@ class DataAccessCommands(BasicKeyCommands, ListCommands, """ -class CoreCommands(AclCommands, DataAccessCommands, ManagementCommands, +class CoreCommands(ACLCommands, DataAccessCommands, ManagementCommands, ModuleCommands, PubSubCommands, ScriptCommands): """ A class containing all of the implemented redis commands. This class is diff --git a/redis/exceptions.py b/redis/exceptions.py index ac88d03bd1..eb6ecc2dc5 100644 --- a/redis/exceptions.py +++ b/redis/exceptions.py @@ -105,9 +105,9 @@ class ClusterDownError(ClusterError, ResponseError): """ Error indicated CLUSTERDOWN error received from cluster. By default Redis Cluster nodes stop accepting queries if they detect there - is at least an hash slot uncovered (no available node is serving it). + is at least a hash slot uncovered (no available node is serving it). This way if the cluster is partially down (for example a range of hash - slots are no longer covered) all the cluster becomes, eventually, + slots are no longer covered) the entire cluster eventually becomes unavailable. It automatically returns available as soon as all the slots are covered again. """ @@ -119,10 +119,10 @@ def __init__(self, resp): class AskError(ResponseError): """ Error indicated ASK error received from cluster. 
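The ASK redirection flow described below, sketched step by step (the slot number and address are examples):

``` pycon
  >>> # GET x    -> migrating source node replies: ASK 12182 127.0.0.1:7002
  >>> # ASKING   -> the client sends ASKING to 127.0.0.1:7002
  >>> # GET x    -> the command is retried there; only this request is redirected
```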
- When a slot is set as MIGRATING, the node will accept all queries that are - about this hash slot, but only if the key in question exists, otherwise the - query is forwarded using a -ASK redirection to the node that is target of - the migration. + When a slot is set as MIGRATING, the node will accept all queries that + pertain to this hash slot, but only if the key in question exists, + otherwise the query is forwarded using a -ASK redirection to the node that + is target of the migration. src node: MIGRATING to dst node get > ASK error ask dst node > ASKING command diff --git a/tests/conftest.py b/tests/conftest.py index 51ebb92407..b19a4fa3e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -148,7 +148,8 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, redis_url = request.config.getoption("--redis-url") else: redis_url = from_url - if not REDIS_INFO["cluster_enabled"]: + cluster_mode = REDIS_INFO["cluster_enabled"] + if not cluster_mode: url_options = parse_url(redis_url) url_options.update(kwargs) pool = redis.ConnectionPool(**url_options) @@ -160,22 +161,34 @@ def _get_client(cls, request, single_connection_client=True, flushdb=True, client = client.client() if request: def teardown(): - if flushdb: - try: - client.flushdb() - except redis.ConnectionError: - # handle cases where a test disconnected a client - # just manually retry the flushdb - client.flushdb() - client.close() - if not REDIS_INFO["cluster_enabled"]: + if not cluster_mode: + if flushdb: + try: + client.flushdb() + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + client.flushdb() + client.close() client.connection_pool.disconnect() else: - client.disconnect_connection_pools() + cluster_teardown(client, flushdb) request.addfinalizer(teardown) return client +def cluster_teardown(client, flushdb): + if flushdb: + try: + client.flushdb(target_nodes='primaries') + except redis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + client.flushdb(target_nodes='primaries') + client.close() + client.disconnect_connection_pools() + + # specifically set to the zero database, because creating # an index on db != 0 raises a ResponseError in redis @pytest.fixture() diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 76c2684da1..6b4bc05e5f 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -4,6 +4,7 @@ import warnings from time import sleep +from tests.test_pubsub import wait_for_message from unittest.mock import call, patch, DEFAULT, Mock from redis import Redis from redis.cluster import get_node_name, ClusterNode, \ @@ -181,6 +182,9 @@ def ok_response(connection, *args, **options): @pytest.mark.onlycluster class TestRedisClusterObj: + """ + Tests for the RedisCluster class + """ def test_host_port_startup_node(self): """ Test that it is possible to use host & port arguments as startup node @@ -201,7 +205,7 @@ def test_startup_nodes(self): ClusterNode(default_host, port_2)] cluster = get_mocked_redis_client(startup_nodes=startup_nodes) assert cluster.get_node(host=default_host, port=port_1) is not None \ - and cluster.get_node(host=default_host, port=port_2) is not None + and cluster.get_node(host=default_host, port=port_2) is not None def test_empty_startup_nodes(self): """ @@ -289,6 +293,17 @@ def test_execute_command_node_flag_random(self, r): called_count += 1 assert called_count == 1 + def test_execute_command_default_node(self, r): + """ + 
Test command execution without node flag is being executed on the + default node + """ + def_node = r.get_default_node() + mock_node_resp(def_node, 'PONG') + assert r.ping() is True + conn = def_node.redis_connection.connection + assert conn.read_response.called + @pytest.mark.filterwarnings("ignore:AskError") def test_ask_redirection(self, r): """ @@ -362,6 +377,7 @@ def parse_response_mock(connection, command_name, def initialize_mock(self): # start with all slots mapped to 7006 self.nodes_cache = {node_7006.name: node_7006} + self.default_node = node_7006 self.slots_cache = {} for i in range(0, 16383): @@ -372,6 +388,7 @@ def initialize_mock(self): def map_7007(self): self.nodes_cache = { node_7007.name: node_7007} + self.default_node = node_7007 self.slots_cache = {} for i in range(0, 16383): @@ -404,7 +421,7 @@ def cmd_init_mock(self, r): RedisCluster, request, flushdb=False) assert len(rc.get_nodes()) == 1 assert rc.get_node(node_name=node_7006.name) is not \ - None + None rc.get('foo') @@ -412,7 +429,7 @@ def cmd_init_mock(self, r): # one failed and one successful call assert len(rc.get_nodes()) == 1 assert rc.get_node(node_name=node_7007.name) is not \ - None + None assert rc.get_node(node_name=node_7006.name) is None assert parse_response.failed_calls == 1 assert parse_response.successful_calls == 1 @@ -482,7 +499,7 @@ def test_keyslot(self, r): def test_get_node_name(self): assert get_node_name(default_host, default_port) == \ - "{0}:{1}".format(default_host, default_port) + "{0}:{1}".format(default_host, default_port) def test_all_nodes(self, r): """ @@ -525,7 +542,7 @@ def raise_cluster_down_error(target_node, *args, **kwargs): with pytest.raises(ClusterDownError): rc.get("bar") assert execute_command.failed_calls == \ - rc.cluster_error_retry_attempts + rc.cluster_error_retry_attempts @pytest.mark.filterwarnings("ignore:ConnectionError") def test_connection_error_overreaches_retry_attempts(self): @@ -546,7 +563,7 @@ def raise_conn_error(target_node, *args, **kwargs): with pytest.raises(ConnectionError): rc.get("bar") assert execute_command.failed_calls == \ - rc.cluster_error_retry_attempts + rc.cluster_error_retry_attempts def test_user_on_connect_function(self, request): """ @@ -561,12 +578,53 @@ def on_connect(connection): _get_client(RedisCluster, request, redis_connect_func=mock) assert mock.called is True + def test_set_default_node_success(self, r): + """ + test successful replacement of the default cluster node + """ + default_node = r.get_default_node() + # get a different node + new_def_node = None + for node in r.get_nodes(): + if node != default_node: + new_def_node = node + break + assert r.set_default_node(new_def_node) is True + assert r.get_default_node() == new_def_node + + def test_set_default_node_failure(self, r): + """ + test failed replacement of the default cluster node + """ + default_node = r.get_default_node() + new_def_node = ClusterNode('1.1.1.1', 1111) + assert r.set_default_node(None) is False + assert r.set_default_node(new_def_node) is False + assert r.get_default_node() == default_node + + def test_get_node_from_key(self, r): + """ + Test that get_node_from_key function returns the correct node + """ + key = 'bar' + slot = r.keyslot(key) + slot_nodes = r.nodes_manager.slots_cache.get(slot) + primary = slot_nodes[0] + assert r.get_node_from_key(key, replica=False) == primary + if len(slot_nodes) > 1: + key_node = r.get_node_from_key(key, replica=True) + assert key_node.server_type == 'replica' + assert key_node in slot_nodes + 
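The default-node and key-lookup helpers exercised by the tests above, in one short sketch:

``` pycon
  >>> node = rc.get_default_node()
  >>> rc.set_default_node(rc.get_primaries()[0])   # returns True on success
  >>> rc.get_node_from_key('foo', replica=False)   # the primary holding 'foo's slot
```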
@pytest.mark.onlycluster class TestClusterRedisCommands: + """ + Tests for RedisCluster unique commands + """ def test_case_insensitive_command_names(self, r): assert r.cluster_response_callbacks['cluster addslots'] == \ - r.cluster_response_callbacks['CLUSTER ADDSLOTS'] + r.cluster_response_callbacks['CLUSTER ADDSLOTS'] def test_get_and_set(self, r): # get and set can't be tested independently of each other @@ -597,32 +655,27 @@ def test_mset_nonatomic(self, r): for k, v in d.items(): assert r[k] == v - def test_dbsize(self, r): - d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} - assert r.mset_nonatomic(d) - assert r.dbsize() == len(d) - def test_config_set(self, r): assert r.config_set('slowlog-log-slower-than', 0) def test_cluster_config_resetstat(self, r): - r.ping() - all_info = r.info() + r.ping(target_nodes='all') + all_info = r.info(target_nodes='all') prior_commands_processed = -1 for node_info in all_info.values(): prior_commands_processed = node_info['total_commands_processed'] assert prior_commands_processed >= 1 - r.config_resetstat() - all_info = r.info() + r.config_resetstat(target_nodes='all') + all_info = r.info(target_nodes='all') for node_info in all_info.values(): reset_commands_processed = node_info['total_commands_processed'] assert reset_commands_processed < prior_commands_processed def test_client_setname(self, r): - r.client_setname('redis_py_test') - res = r.client_getname() - for client_name in res.values(): - assert client_name == 'redis_py_test' + node = r.get_random_node() + r.client_setname('redis_py_test', target_nodes=node) + client_name = r.client_getname(target_nodes=node) + assert client_name == 'redis_py_test' def test_exists(self, r): d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} @@ -673,7 +726,7 @@ def test_pubsub_channels_merge_results(self, r): i += 1 # Assert that the cluster's pubsub_channels function returns ALL of # the cluster's channels - result = r.pubsub_channels() + result = r.pubsub_channels(target_nodes='all') result.sort() assert result == channels @@ -696,7 +749,8 @@ def test_pubsub_numsub_merge_results(self, r): assert sub_chann_num == [(b_channel, 1)] # Assert that the cluster's pubsub_numsub function returns ALL clients # subscribed to this channel in the entire cluster - assert r.pubsub_numsub(channel) == [(b_channel, len(nodes))] + assert r.pubsub_numsub(channel, target_nodes='all') == \ + [(b_channel, len(nodes))] def test_pubsub_numpat_merge_results(self, r): nodes = r.get_nodes() @@ -716,7 +770,35 @@ def test_pubsub_numpat_merge_results(self, r): assert sub_num_pat == 1 # Assert that the cluster's pubsub_numsub function returns ALL clients # subscribed to this channel in the entire cluster - assert r.pubsub_numpat() == len(nodes) + assert r.pubsub_numpat(target_nodes='all') == len(nodes) + + @skip_if_server_version_lt('2.8.0') + def test_cluster_pubsub_channels(self, r): + p = r.pubsub() + p.subscribe('foo', 'bar', 'baz', 'quux') + for i in range(4): + assert wait_for_message(p)['type'] == 'subscribe' + expected = [b'bar', b'baz', b'foo', b'quux'] + assert all([channel in r.pubsub_channels(target_nodes='all') + for channel in expected]) + + @skip_if_server_version_lt('2.8.0') + def test_cluster_pubsub_numsub(self, r): + p1 = r.pubsub() + p1.subscribe('foo', 'bar', 'baz') + for i in range(3): + assert wait_for_message(p1)['type'] == 'subscribe' + p2 = r.pubsub() + p2.subscribe('bar', 'baz') + for i in range(2): + assert wait_for_message(p2)['type'] == 'subscribe' + p3 = r.pubsub() + p3.subscribe('baz') + assert 
wait_for_message(p3)['type'] == 'subscribe' + + channels = [(b'foo', 1), (b'bar', 2), (b'baz', 3)] + assert r.pubsub_numsub('foo', 'bar', 'baz', target_nodes='all') \ + == channels def test_cluster_slots(self, r): mock_all_nodes_resp(r, default_cluster_slots) @@ -725,7 +807,7 @@ def test_cluster_slots(self, r): assert len(default_cluster_slots) == len(cluster_slots) assert cluster_slots.get((0, 8191)) is not None assert cluster_slots.get((0, 8191)).get('primary') == \ - ('127.0.0.1', 7000) + ('127.0.0.1', 7000) def test_cluster_addslots(self, r): node = r.get_random_node() @@ -780,9 +862,9 @@ def test_cluster_keyslot(self, r): assert r.cluster_keyslot('foo') == 12182 def test_cluster_meet(self, r): - node = r.get_random_node() + node = r.get_default_node() mock_node_resp(node, 'OK') - assert r.cluster_meet(node, '127.0.0.1', 6379) is True + assert r.cluster_meet('127.0.0.1', 6379) is True def test_cluster_nodes(self, r): response = ( @@ -808,7 +890,7 @@ def test_cluster_nodes(self, r): assert len(nodes) == 7 assert nodes.get('172.17.0.7:7006') is not None assert nodes.get('172.17.0.7:7006').get('node_id') == \ - "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b" def test_cluster_replicate(self, r): node = r.get_random_node() @@ -816,16 +898,17 @@ def test_cluster_replicate(self, r): mock_all_nodes_resp(r, 'OK') assert r.cluster_replicate(node, 'c8253bae761cb61857d') is True results = r.cluster_replicate(all_replicas, 'c8253bae761cb61857d') - for res in results.values(): - assert res is True + if isinstance(results, dict): + for res in results.values(): + assert res is True + else: + assert results is True def test_cluster_reset(self, r): - node = r.get_random_node() - all_nodes = r.get_nodes() mock_all_nodes_resp(r, 'OK') - assert r.cluster_reset(node) is True - assert r.cluster_reset(node, False) is True - all_results = r.cluster_reset(all_nodes, False) + assert r.cluster_reset() is True + assert r.cluster_reset(False) is True + all_results = r.cluster_reset(False, target_nodes='all') for res in all_results.values(): assert res is True @@ -846,11 +929,9 @@ def test_cluster_get_keys_in_slot(self, r): assert keys == response def test_cluster_set_config_epoch(self, r): - node = r.get_random_node() - all_nodes = r.get_nodes() mock_all_nodes_resp(r, 'OK') - assert r.cluster_set_config_epoch(node, 3) is True - all_results = r.cluster_set_config_epoch(all_nodes, 3) + assert r.cluster_set_config_epoch(3) is True + all_results = r.cluster_set_config_epoch(3, target_nodes='all') for res in all_results.values(): assert res is True @@ -885,30 +966,26 @@ def test_cluster_replicas(self, r): assert replicas.get('127.0.0.1:6377') is not None assert replicas.get('127.0.0.1:6378') is not None assert replicas.get('127.0.0.1:6378').get('node_id') == \ - 'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d' + 'r4xfga22229cf3c652b6fca0d09ff69f3e0d4d' def test_readonly(self): r = get_mocked_redis_client(host=default_host, port=default_port) - node = r.get_random_node() - all_replicas = r.get_replicas() mock_all_nodes_resp(r, 'OK') - assert r.readonly(node) is True - all_replicas_results = r.readonly() + assert r.readonly() is True + all_replicas_results = r.readonly(target_nodes='replicas') for res in all_replicas_results.values(): assert res is True - for replica in all_replicas: + for replica in r.get_replicas(): assert replica.redis_connection.connection.read_response.called def test_readwrite(self): r = get_mocked_redis_client(host=default_host, port=default_port) - node = 
r.get_random_node() mock_all_nodes_resp(r, 'OK') - all_replicas = r.get_replicas() - assert r.readwrite(node) is True - all_replicas_results = r.readwrite() + assert r.readwrite() is True + all_replicas_results = r.readwrite(target_nodes='replicas') for res in all_replicas_results.values(): assert res is True - for replica in all_replicas: + for replica in r.get_replicas(): assert replica.redis_connection.connection.read_response.called def test_bgsave(self, r): @@ -929,13 +1006,23 @@ def test_info(self, r): assert isinstance(info, dict) assert info['db0']['keys'] == 3 + def _init_slowlog_test(self, r, node): + slowlog_lim = r.config_get('slowlog-log-slower-than', + target_nodes=node) + assert r.config_set('slowlog-log-slower-than', 0, target_nodes=node) \ + is True + return slowlog_lim['slowlog-log-slower-than'] + + def _teardown_slowlog_test(self, r, node, prev_limit): + assert r.config_set('slowlog-log-slower-than', prev_limit, + target_nodes=node) is True + def test_slowlog_get(self, r, slowlog): - assert r.slowlog_reset() unicode_string = chr(3456) + 'abcd' + chr(3421) + node = r.get_node_from_key(unicode_string) + slowlog_limit = self._init_slowlog_test(r, node) + assert r.slowlog_reset(target_nodes=node) r.get(unicode_string) - - slot = r.keyslot(unicode_string) - node = r.nodes_manager.get_node_from_slot(slot) slowlog = r.slowlog_get(target_nodes=node) assert isinstance(slowlog, list) commands = [log['command'] for log in slowlog] @@ -953,15 +1040,19 @@ def test_slowlog_get(self, r, slowlog): # make sure other attributes are typed correctly assert isinstance(slowlog[0]['start_time'], int) assert isinstance(slowlog[0]['duration'], int) + # rollback the slowlog limit to its original value + self._teardown_slowlog_test(r, node, slowlog_limit) def test_slowlog_get_limit(self, r, slowlog): assert r.slowlog_reset() + node = r.get_node_from_key('foo') + slowlog_limit = self._init_slowlog_test(r, node) r.get('foo') - node = r.nodes_manager.get_node_from_slot(key_slot(b'foo')) slowlog = r.slowlog_get(1, target_nodes=node) assert isinstance(slowlog, list) # only one command, based on the number we passed to slowlog_get() assert len(slowlog) == 1 + self._teardown_slowlog_test(r, node, slowlog_limit) def test_slowlog_length(self, r, slowlog): r.get('foo') @@ -1079,9 +1170,9 @@ def test_client_info(self, r): @skip_if_server_version_lt('2.6.9') def test_client_kill(self, r, r2): node = r.get_primaries()[0] - r.client_setname('redis-py-c1') - r2.client_setname('redis-py-c2') - clients = [client for client in r.client_list()[node.name] + r.client_setname('redis-py-c1', target_nodes='all') + r2.client_setname('redis-py-c2', target_nodes='all') + clients = [client for client in r.client_list(target_nodes=node) if client.get('name') in ['redis-py-c1', 'redis-py-c2']] assert len(clients) == 2 clients_by_name = dict([(client.get('name'), client) @@ -1090,7 +1181,7 @@ def test_client_kill(self, r, r2): client_addr = clients_by_name['redis-py-c2'].get('addr') assert r.client_kill(client_addr, target_nodes=node) is True - clients = [client for client in r.client_list()[node.name] + clients = [client for client in r.client_list(target_nodes=node) if client.get('name') in ['redis-py-c1', 'redis-py-c2']] assert len(clients) == 1 assert clients[0].get('name') == 'redis-py-c1' @@ -1538,9 +1629,65 @@ def test_cluster_georadius_store_dist(self, r): # instead of save the geo score, the distance is saved. 
assert r.zscore('{foo}places_barcelona', 'place1') == 88.05060698409301 + def test_cluster_dbsize(self, r): + d = {'a': b'1', 'b': b'2', 'c': b'3', 'd': b'4'} + assert r.mset_nonatomic(d) + assert r.dbsize(target_nodes='primaries') == len(d) + + def test_cluster_keys(self, r): + assert r.keys() == [] + keys_with_underscores = {b'test_a', b'test_b'} + keys = keys_with_underscores.union({b'testc'}) + for key in keys: + r[key] = 1 + assert set(r.keys(pattern='test_*', target_nodes='primaries')) == \ + keys_with_underscores + assert set(r.keys(pattern='test*', target_nodes='primaries')) == keys + + # SCAN COMMANDS + @skip_if_server_version_lt('2.8.0') + def test_cluster_scan(self, r): + r.set('a', 1) + r.set('b', 2) + r.set('c', 3) + cursor, keys = r.scan(target_nodes='primaries') + assert cursor == 0 + assert set(keys) == {b'a', b'b', b'c'} + _, keys = r.scan(match='a', target_nodes='primaries') + assert set(keys) == {b'a'} + + @skip_if_server_version_lt("6.0.0") + def test_cluster_scan_type(self, r): + r.sadd('a-set', 1) + r.hset('a-hash', 'foo', 2) + r.lpush('a-list', 'aux', 3) + _, keys = r.scan(match='a*', _type='SET', target_nodes='primaries') + assert set(keys) == {b'a-set'} + + @skip_if_server_version_lt('2.8.0') + def test_cluster_scan_iter(self, r): + r.set('a', 1) + r.set('b', 2) + r.set('c', 3) + keys = list(r.scan_iter(target_nodes='primaries')) + assert set(keys) == {b'a', b'b', b'c'} + keys = list(r.scan_iter(match='a', target_nodes='primaries')) + assert set(keys) == {b'a'} + + def test_cluster_randomkey(self, r): + node = r.get_node_from_key('{foo}') + assert r.randomkey(target_nodes=node) is None + for key in ('{foo}a', '{foo}b', '{foo}c'): + r[key] = 1 + assert r.randomkey(target_nodes=node) in \ + (b'{foo}a', b'{foo}b', b'{foo}c') + @pytest.mark.onlycluster class TestNodesManager: + """ + Tests for the NodesManager class + """ def test_load_balancer(self, r): n_manager = r.nodes_manager lb = n_manager.read_load_balancer @@ -1859,3 +2006,76 @@ def cmd_init_mock(self, r): rc = RedisCluster(startup_nodes=[node_1, node_2]) assert rc.get_node(host=default_host, port=7001) is not None assert rc.get_node(host=default_host, port=7002) is not None + + +@pytest.mark.onlycluster +class TestClusterPubSubObject: + """ + Tests for the ClusterPubSub class + """ + def test_init_pubsub_with_host_and_port(self, r): + """ + Test creation of pubsub instance with passed host and port + """ + node = r.get_default_node() + p = r.pubsub(host=node.host, port=node.port) + assert p.get_pubsub_node() == node + + def test_init_pubsub_with_node(self, r): + """ + Test creation of pubsub instance with passed node + """ + node = r.get_default_node() + p = r.pubsub(node=node) + assert p.get_pubsub_node() == node + + def test_init_pubusub_without_specifying_node(self, r): + """ + Test creation of pubsub instance without specifying a node. The node + should be determined based on the keyslot of the first command + execution. + """ + channel_name = 'foo' + node = r.get_node_from_key(channel_name) + p = r.pubsub() + assert p.get_pubsub_node() is None + p.subscribe(channel_name) + assert p.get_pubsub_node() == node + + def test_init_pubsub_with_a_non_existent_node(self, r): + """ + Test creation of pubsub instance with node that doesn't exists in the + cluster. RedisClusterException should be raised. 
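A sketch of how a cluster pubsub instance picks its node, per the tests above (the channel name is illustrative):

``` pycon
  >>> p = rc.pubsub()        # no node bound yet
  >>> p.subscribe('chat')    # node resolved from 'chat's key slot
  >>> p.get_pubsub_node()
```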
+ """ + node = ClusterNode('1.1.1.1', 1111) + with pytest.raises(RedisClusterException): + r.pubsub(node) + + def test_init_pubsub_with_a_non_existent_host_port(self, r): + """ + Test creation of pubsub instance with host and port that don't belong + to a node in the cluster. + RedisClusterException should be raised. + """ + with pytest.raises(RedisClusterException): + r.pubsub(host='1.1.1.1', port=1111) + + def test_init_pubsub_host_or_port(self, r): + """ + Test creation of pubsub instance with host but without port, and vice + versa. DataError should be raised. + """ + with pytest.raises(DataError): + r.pubsub(host='localhost') + + with pytest.raises(DataError): + r.pubsub(port=16379) + + def test_get_redis_connection(self, r): + """ + Test that get_redis_connection() returns the redis connection of the + set pubsub node + """ + node = r.get_default_node() + p = r.pubsub(node=node) + assert p.get_redis_connection() == node.redis_connection diff --git a/tests/test_command_parser.py b/tests/test_command_parser.py new file mode 100644 index 0000000000..ba129ba673 --- /dev/null +++ b/tests/test_command_parser.py @@ -0,0 +1,62 @@ +import pytest + +from redis.commands import CommandsParser + + +class TestCommandsParser: + def test_init_commands(self, r): + commands_parser = CommandsParser(r) + assert commands_parser.commands is not None + assert 'get' in commands_parser.commands + + def test_get_keys_predetermined_key_location(self, r): + commands_parser = CommandsParser(r) + args1 = ['GET', 'foo'] + args2 = ['OBJECT', 'encoding', 'foo'] + args3 = ['MGET', 'foo', 'bar', 'foobar'] + assert commands_parser.get_keys(r, *args1) == ['foo'] + assert commands_parser.get_keys(r, *args2) == ['foo'] + assert commands_parser.get_keys(r, *args3) == ['foo', 'bar', 'foobar'] + + @pytest.mark.filterwarnings("ignore:ResponseError") + def test_get_moveable_keys(self, r): + commands_parser = CommandsParser(r) + args1 = ['EVAL', 'return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}', 2, 'key1', + 'key2', 'first', 'second'] + args2 = ['XREAD', 'COUNT', 2, b'STREAMS', 'mystream', 'writers', 0, 0] + args3 = ['ZUNIONSTORE', 'out', 2, 'zset1', 'zset2', 'WEIGHTS', 2, 3] + args4 = ['GEORADIUS', 'Sicily', 15, 37, 200, 'km', 'WITHCOORD', + b'STORE', 'out'] + args5 = ['MEMORY USAGE', 'foo'] + args6 = ['MIGRATE', '192.168.1.34', 6379, "", 0, 5000, b'KEYS', + 'key1', 'key2', 'key3'] + args7 = ['MIGRATE', '192.168.1.34', 6379, "key1", 0, 5000] + args8 = ['STRALGO', 'LCS', 'STRINGS', 'string_a', 'string_b'] + args9 = ['STRALGO', 'LCS', 'KEYS', 'key1', 'key2'] + + assert commands_parser.get_keys( + r, *args1).sort() == ['key1', 'key2'].sort() + assert commands_parser.get_keys( + r, *args2).sort() == ['mystream', 'writers'].sort() + assert commands_parser.get_keys( + r, *args3).sort() == ['out', 'zset1', 'zset2'].sort() + assert commands_parser.get_keys( + r, *args4).sort() == ['Sicily', 'out'].sort() + assert commands_parser.get_keys(r, *args5).sort() == ['foo'].sort() + assert commands_parser.get_keys( + r, *args6).sort() == ['key1', 'key2', 'key3'].sort() + assert commands_parser.get_keys(r, *args7).sort() == ['key1'].sort() + assert commands_parser.get_keys(r, *args8) is None + assert commands_parser.get_keys( + r, *args9).sort() == ['key1', 'key2'].sort() + + def test_get_pubsub_keys(self, r): + commands_parser = CommandsParser(r) + args1 = ['PUBLISH', 'foo', 'bar'] + args2 = ['PUBSUB NUMSUB', 'foo1', 'foo2', 'foo3'] + args3 = ['PUBSUB channels', '*'] + args4 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3'] + assert 
commands_parser.get_keys(r, *args1) == ['foo'] + assert commands_parser.get_keys(r, *args2) == ['foo1', 'foo2', 'foo3'] + assert commands_parser.get_keys(r, *args3) == ['*'] + assert commands_parser.get_keys(r, *args4) == ['foo1', 'foo2', 'foo3'] diff --git a/tests/test_commands.py b/tests/test_commands.py index 04861403f4..0f66dabdc3 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -8,7 +8,6 @@ from redis.client import parse_info from redis import exceptions -from redis.commands import CommandsParser from .conftest import ( _get_client, skip_if_server_version_gte, @@ -557,6 +556,7 @@ def test_config_set(self, r): finally: assert r.config_set('dbfilename', rdbname) + @pytest.mark.onlynoncluster def test_dbsize(self, r): r['a'] = 'foo' r['b'] = 'bar' @@ -1010,6 +1010,7 @@ def test_incrbyfloat(self, r): assert r.incrbyfloat('a', 1.1) == 2.1 assert float(r['a']) == float(2.1) + @pytest.mark.onlynoncluster def test_keys(self, r): assert r.keys() == [] keys_with_underscores = {b'test_a', b'test_b'} @@ -1128,6 +1129,7 @@ def test_hrandfield(self, r): # with duplications assert len(r.hrandfield('key', -10)) == 10 + @pytest.mark.onlynoncluster def test_randomkey(self, r): assert r.randomkey() is None for key in ('a', 'b', 'c'): @@ -1518,6 +1520,7 @@ def test_rpushx(self, r): assert r.lrange('a', 0, -1) == [b'1', b'2', b'3', b'4'] # SCAN COMMANDS + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.0') def test_scan(self, r): r.set('a', 1) @@ -1529,6 +1532,7 @@ def test_scan(self, r): _, keys = r.scan(match='a') assert set(keys) == {b'a'} + @pytest.mark.onlynoncluster @skip_if_server_version_lt("6.0.0") def test_scan_type(self, r): r.sadd('a-set', 1) @@ -1537,6 +1541,7 @@ def test_scan_type(self, r): _, keys = r.scan(match='a*', _type='SET') assert set(keys) == {b'a-set'} + @pytest.mark.onlynoncluster @skip_if_server_version_lt('2.8.0') def test_scan_iter(self, r): r.set('a', 1) @@ -3895,62 +3900,3 @@ def test_floating_point_encoding(self, r): timestamp = 1349673917.939762 r.zadd('a', {'a1': timestamp}) assert r.zscore('a', 'a1') == timestamp - - -class TestCommandsParser: - def test_init_commands(self, r): - commands_parser = CommandsParser(r) - assert commands_parser.commands is not None - assert 'get' in commands_parser.commands - - def test_get_keys_predetermined_key_location(self, r): - commands_parser = CommandsParser(r) - args1 = ['GET', 'foo'] - args2 = ['OBJECT', 'encoding', 'foo'] - args3 = ['MGET', 'foo', 'bar', 'foobar'] - assert commands_parser.get_keys(r, *args1) == ['foo'] - assert commands_parser.get_keys(r, *args2) == ['foo'] - assert commands_parser.get_keys(r, *args3) == ['foo', 'bar', 'foobar'] - - @pytest.mark.filterwarnings("ignore:ResponseError") - def test_get_moveable_keys(self, r): - commands_parser = CommandsParser(r) - args1 = ['EVAL', 'return {KEYS[1],KEYS[2],ARGV[1],ARGV[2]}', 2, 'key1', - 'key2', 'first', 'second'] - args2 = ['XREAD', 'COUNT', 2, b'STREAMS', 'mystream', 'writers', 0, 0] - args3 = ['ZUNIONSTORE', 'out', 2, 'zset1', 'zset2', 'WEIGHTS', 2, 3] - args4 = ['GEORADIUS', 'Sicily', 15, 37, 200, 'km', 'WITHCOORD', - b'STORE', 'out'] - args5 = ['MEMORY USAGE', 'foo'] - args6 = ['MIGRATE', '192.168.1.34', 6379, "", 0, 5000, b'KEYS', - 'key1', 'key2', 'key3'] - args7 = ['MIGRATE', '192.168.1.34', 6379, "key1", 0, 5000] - args8 = ['STRALGO', 'LCS', 'STRINGS', 'string_a', 'string_b'] - args9 = ['STRALGO', 'LCS', 'KEYS', 'key1', 'key2'] - - assert commands_parser.get_keys( - r, *args1).sort() == ['key1', 'key2'].sort() - assert 
commands_parser.get_keys(
-            r, *args2).sort() == ['mystream', 'writers'].sort()
-        assert commands_parser.get_keys(
-            r, *args3).sort() == ['out', 'zset1', 'zset2'].sort()
-        assert commands_parser.get_keys(
-            r, *args4).sort() == ['Sicily', 'out'].sort()
-        assert commands_parser.get_keys(r, *args5).sort() == ['foo'].sort()
-        assert commands_parser.get_keys(
-            r, *args6).sort() == ['key1', 'key2', 'key3'].sort()
-        assert commands_parser.get_keys(r, *args7).sort() == ['key1'].sort()
-        assert commands_parser.get_keys(r, *args8) is None
-        assert commands_parser.get_keys(
-            r, *args9).sort() == ['key1', 'key2'].sort()
-
-    def test_get_pubsub_keys(self, r):
-        commands_parser = CommandsParser(r)
-        args1 = ['PUBLISH', 'foo', 'bar']
-        args2 = ['PUBSUB NUMSUB', 'foo1', 'foo2', 'foo3']
-        args3 = ['PUBSUB channels', '*']
-        args4 = ['SUBSCRIBE', 'foo1', 'foo2', 'foo3']
-        assert commands_parser.get_keys(r, *args1) == ['foo']
-        assert commands_parser.get_keys(r, *args2) == ['foo1', 'foo2', 'foo3']
-        assert commands_parser.get_keys(r, *args3) == ['*']
-        assert commands_parser.get_keys(r, *args4) == ['foo1', 'foo2', 'foo3']
diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py
index 30ca9f1124..4c754b0a2b 100644
--- a/tests/test_pubsub.py
+++ b/tests/test_pubsub.py
@@ -477,6 +477,7 @@ def test_channel_subscribe(self, r):


 class TestPubSubSubcommands:
+    @pytest.mark.onlynoncluster
     @skip_if_server_version_lt('2.8.0')
     def test_pubsub_channels(self, r):
         p = r.pubsub()
@@ -486,6 +487,7 @@ def test_pubsub_channels(self, r):
         expected = [b'bar', b'baz', b'foo', b'quux']
         assert all([channel in r.pubsub_channels() for channel in expected])

+    @pytest.mark.onlynoncluster
     @skip_if_server_version_lt('2.8.0')
     def test_pubsub_numsub(self, r):
         p1 = r.pubsub()

From 8124a2b31c220bd4d6dea465fcf9680909df1ee5 Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Mon, 15 Nov 2021 18:56:32 +0200
Subject: [PATCH 14/22] Added ClusterPipeline documentation and tests

---
 README.md             |  38 ++++
 redis/cluster.py      |  42 ++++-
 tests/test_cluster.py | 415 +++++++++++++++++++++++++++++++++++++++++-
 tox.ini               |   1 -
 4 files changed, 475 insertions(+), 21 deletions(-)

diff --git a/README.md b/README.md
index 9274aa6bb3..6e7d964b2b 100644
--- a/README.md
+++ b/README.md
@@ -1171,6 +1171,44 @@ readwrite() method.
   >>> rc_readonly.get('{foo}1')
 ```

+**Cluster Pipeline**
+
+ClusterPipeline is a subclass of RedisCluster that provides support for Redis
+pipelines in cluster mode.
+When calling the execute() command, all the commands are grouped by the node
+on which they will be executed, and are then executed by the respective nodes
+in parallel. The pipeline instance will wait for all the nodes to respond
+before returning the result to the caller. Command responses are returned as a
+list sorted in the same order in which they were sent.
+Pipelines can be used to dramatically increase the throughput of Redis Cluster
+by significantly reducing the number of network round trips between the
+client and the server.
+
+``` pycon
+  >>> with rc.pipeline() as pipe:
+  >>>     pipe.set('foo', 'value1')
+  >>>     pipe.set('bar', 'value2')
+  >>>     pipe.get('foo')
+  >>>     pipe.get('bar')
+  >>>     print(pipe.execute())
+  [True, True, b'value1', b'value2']
+  >>> pipe.set('foo1', 'bar1').get('foo1').execute()
+  [True, b'bar1']
+```
+Please note:
+- RedisCluster pipelines currently only support key-based commands.
+- The pipeline inherits its 'read_from_replicas' value from the cluster instance.
+Thus, if read_from_replicas is enabled in the cluster instance, the pipeline
+will also direct read commands to replicas.
+- The 'transaction' option is NOT supported in cluster-mode. In non-cluster mode,
+the 'transaction' option is available when executing pipelines. This wraps the
+pipeline commands with MULTI/EXEC commands, and effectively turns the pipeline
+commands into a single transaction block. This means that all commands are
+executed sequentially without any interruptions from other clients. However,
+in cluster-mode this is not possible, because commands are partitioned
+according to their respective destination nodes. This means that we cannot
+turn the pipeline commands into one transaction block, because in most cases
+they are split up into several smaller pipelines.
+
 See [Redis Cluster tutorial](https://redis.io/topics/cluster-tutorial) and
diff --git a/redis/cluster.py b/redis/cluster.py
index ee40acfdae..ff2714dc5d 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -592,7 +592,9 @@ def get_nodes(self):

     def get_node_from_key(self, key, replica=False):
         """
-        Get the node that holds the key's slot
+        Get the node that holds the key's slot.
+        If replica is set to True but the slot doesn't have any replicas, None
+        is returned.
         """
         slot = self.keyslot(key)
         slot_cache = self.nodes_manager.slots_cache.get(slot)
@@ -600,9 +602,13 @@ def get_node_from_key(self, key, replica=False):
             raise SlotNotCoveredError(
                 'Slot "{0}" is not covered by the cluster.'.format(slot)
             )
-        node_idx = 0
-        if replica and len(self.nodes_manager.slots_cache[slot]) > 1:
+        if replica and len(self.nodes_manager.slots_cache[slot]) < 2:
+            return None
+        elif replica:
             node_idx = 1
+        else:
+            # primary
+            node_idx = 0

         return slot_cache[node_idx]

@@ -633,8 +639,7 @@ def pubsub(self, node=None, host=None, port=None, **kwargs):
         """
         return ClusterPubSub(self, node=node, host=host, port=port, **kwargs)

-    def pipeline(self, transaction=None,
-                 shard_hint=None, read_from_replicas=False):
+    def pipeline(self, transaction=None, shard_hint=None):
         """
         Cluster impl:
             Pipelines do not work in cluster mode the same way they
@@ -657,7 +662,8 @@ def pipeline(self, transaction=None,
             result_callbacks=self.result_callbacks,
             cluster_response_callbacks=self.cluster_response_callbacks,
             cluster_error_retry_attempts=self.cluster_error_retry_attempts,
-            read_from_replicas=read_from_replicas,
+            read_from_replicas=self.read_from_replicas,
+            reinitialize_steps=self.reinitialize_steps
         )

     def _determine_nodes(self, *args, **kwargs):
@@ -1235,6 +1241,7 @@ def initialize(self):
         :startup_nodes:
             Responsible for discovering other nodes in the cluster
         """
+        log.debug("Initializing the nodes' topology of the cluster")
         self.reset()
         tmp_nodes_cache = {}
         tmp_slots = {}
@@ -1546,7 +1553,7 @@ class ClusterPipeline(RedisCluster):
     def __init__(self, nodes_manager, result_callbacks=None,
                  cluster_response_callbacks=None, startup_nodes=None,
                  read_from_replicas=False, cluster_error_retry_attempts=3,
-                 **kwargs):
+                 reinitialize_steps=10, **kwargs):
         """
         """
         log.info("Creating new instance of ClusterPipeline")
@@ -1560,7 +1567,8 @@ def __init__(self, nodes_manager, result_callbacks=None,
         self.command_flags = self.__class__.COMMAND_FLAGS.copy()
         self.cluster_response_callbacks = cluster_response_callbacks
         self.cluster_error_retry_attempts = cluster_error_retry_attempts
-
+        self.reinitialize_counter = 0
+        self.reinitialize_steps = reinitialize_steps
         self.encoder = Encoder(
             kwargs.get("encoding", "utf-8"),
             kwargs.get("encoding_errors", "strict"),
@@ -1825,8
+1833,15 @@ def _send_cluster_commands(self, stack, # If a lot of commands have failed, we'll be setting the # flag to rebuild the slots table from scratch. # So MOVED errors should correct themselves fairly quickly. - self.connection_pool.nodes. \ - increment_reinitialize_counter(len(attempt)) + msg = 'An exception occurred during pipeline execution. ' \ + 'args: {0}, error: {1} {2}'.\ + format(attempt[-1].args, + type(attempt[-1].result).__name__, + str(attempt[-1].result)) + log.exception(msg) + self.reinitialize_counter += 1 + if self._should_reinitialized(): + self.nodes_manager.initialize() for c in attempt: try: # send each command individually like we @@ -1852,6 +1867,11 @@ def _fail_on_redirect(self, allow_redirections): raise RedisClusterException( "ASK & MOVED redirection not allowed in this pipeline") + def eval(self): + """ + """ + raise RedisClusterException("method eval() is not implemented") + def multi(self): """ """ @@ -1951,6 +1971,8 @@ def inner(*args, **kwargs): ClusterPipeline.sort = block_pipeline_command(RedisCluster.sort) ClusterPipeline.sunion = block_pipeline_command(RedisCluster.sunion) ClusterPipeline.sunionstore = block_pipeline_command(RedisCluster.sunionstore) +ClusterPipeline.readwrite = block_pipeline_command(RedisCluster.readwrite) +ClusterPipeline.readonly = block_pipeline_command(RedisCluster.readonly) class PipelineCommand(object): diff --git a/tests/test_cluster.py b/tests/test_cluster.py index 3573fffafa..ced40aa4ac 100644 --- a/tests/test_cluster.py +++ b/tests/test_cluster.py @@ -117,6 +117,13 @@ def mock_node_resp(node, response): return node +def mock_node_resp_func(node, func): + connection = Mock() + connection.read_response.side_effect = func + node.redis_connection.connection = connection + return node + + def mock_all_nodes_resp(rc, response): for node in rc.get_nodes(): mock_node_resp(node, response) @@ -307,7 +314,6 @@ def test_execute_command_default_node(self, r): conn = def_node.redis_connection.connection assert conn.read_response.called - @pytest.mark.filterwarnings("ignore:AskError") def test_ask_redirection(self, r): """ Test that the server handles ASK response. @@ -334,21 +340,18 @@ def ok_response(connection, *args, **options): assert r.execute_command("SET", "foo", "bar") == "MOCK_OK" - @pytest.mark.filterwarnings("ignore:MovedError") def test_moved_redirection(self, request): """ Test that the client handles MOVED response. """ moved_redirection_helper(request, failover=False) - @pytest.mark.filterwarnings("ignore:MovedError") def test_moved_redirection_after_failover(self, request): """ Test that the client handles MOVED response after a failover. 
""" moved_redirection_helper(request, failover=True) - @pytest.mark.filterwarnings("ignore:ClusterDownError") def test_refresh_using_specific_nodes(self, request): """ Test making calls on specific nodes when the cluster has failed over to @@ -526,7 +529,6 @@ def test_all_nodes_masters(self, r): for node in r.get_primaries(): assert node in nodes - @pytest.mark.filterwarnings("ignore:ClusterDownError") def test_cluster_down_overreaches_retry_attempts(self): """ When ClusterDownError is thrown, test that we retry executing the @@ -549,7 +551,6 @@ def raise_cluster_down_error(target_node, *args, **kwargs): assert execute_command.failed_calls == \ rc.cluster_error_retry_attempts - @pytest.mark.filterwarnings("ignore:ConnectionError") def test_connection_error_overreaches_retry_attempts(self): """ When ConnectionError is thrown, test that we retry executing the @@ -616,10 +617,10 @@ def test_get_node_from_key(self, r): slot_nodes = r.nodes_manager.slots_cache.get(slot) primary = slot_nodes[0] assert r.get_node_from_key(key, replica=False) == primary - if len(slot_nodes) > 1: - key_node = r.get_node_from_key(key, replica=True) - assert key_node.server_type == 'replica' - assert key_node in slot_nodes + replica = r.get_node_from_key(key, replica=True) + if replica is not None: + assert replica.server_type == REPLICA + assert replica in slot_nodes @pytest.mark.onlycluster @@ -2087,3 +2088,397 @@ def test_get_redis_connection(self, r): node = r.get_default_node() p = r.pubsub(node=node) assert p.get_redis_connection() == node.redis_connection + + +@pytest.mark.onlycluster +class TestClusterPipeline: + """ + Tests for the ClusterPipeline class + """ + + def test_blocked_methods(self, r): + """ + Currently some method calls on a Cluster pipeline + is blocked when using in cluster mode. + They maybe implemented in the future. + """ + pipe = r.pipeline() + with pytest.raises(RedisClusterException): + pipe.multi() + + with pytest.raises(RedisClusterException): + pipe.immediate_execute_command() + + with pytest.raises(RedisClusterException): + pipe._execute_transaction(None, None, None) + + with pytest.raises(RedisClusterException): + pipe.load_scripts() + + with pytest.raises(RedisClusterException): + pipe.watch() + + with pytest.raises(RedisClusterException): + pipe.unwatch() + + with pytest.raises(RedisClusterException): + pipe.script_load_for_pipeline(None) + + with pytest.raises(RedisClusterException): + pipe.eval() + + def test_blocked_arguments(self, r): + """ + Currently some arguments is blocked when using in cluster mode. + They maybe implemented in the future. 
+ """ + with pytest.raises(RedisClusterException) as ex: + r.pipeline(transaction=True) + + assert str(ex.value).startswith( + "transaction is deprecated in cluster mode") is True + + with pytest.raises(RedisClusterException) as ex: + r.pipeline(shard_hint=True) + + assert str(ex.value).startswith( + "shard_hint is deprecated in cluster mode") is True + + def test_redis_cluster_pipeline(self, r): + """ + Test that we can use a pipeline with the RedisCluster class + """ + with r.pipeline() as pipe: + pipe.set("foo", "bar") + pipe.get("foo") + assert pipe.execute() == [True, b'bar'] + + def test_mget_disabled(self, r): + """ + Test that mget is disabled for ClusterPipeline + """ + with r.pipeline() as pipe: + with pytest.raises(RedisClusterException): + pipe.mget(['a']) + + def test_mset_disabled(self, r): + """ + Test that mset is disabled for ClusterPipeline + """ + with r.pipeline() as pipe: + with pytest.raises(RedisClusterException): + pipe.mset({'a': 1, 'b': 2}) + + def test_rename_disabled(self, r): + """ + Test that rename is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.rename('a', 'b') + + def test_renamenx_disabled(self, r): + """ + Test that renamenx is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.renamenx('a', 'b') + + def test_delete_single(self, r): + """ + Test a single delete operation + """ + r['a'] = 1 + with r.pipeline(transaction=False) as pipe: + pipe.delete('a') + assert pipe.execute() == [1] + + def test_multi_delete_unsupported(self, r): + """ + Test that multi delete operation is unsupported + """ + with r.pipeline(transaction=False) as pipe: + r['a'] = 1 + r['b'] = 2 + with pytest.raises(RedisClusterException): + pipe.delete('a', 'b') + + def test_brpoplpush_disabled(self, r): + """ + Test that brpoplpush is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.brpoplpush() + + def test_rpoplpush_disabled(self, r): + """ + Test that rpoplpush is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.rpoplpush() + + def test_sort_disabled(self, r): + """ + Test that sort is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sort() + + def test_sdiff_disabled(self, r): + """ + Test that sdiff is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sdiff() + + def test_sdiffstore_disabled(self, r): + """ + Test that sdiffstore is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sdiffstore() + + def test_sinter_disabled(self, r): + """ + Test that sinter is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sinter() + + def test_sinterstore_disabled(self, r): + """ + Test that sinterstore is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): + pipe.sinterstore() + + def test_smove_disabled(self, r): + """ + Test that move is disabled for ClusterPipeline + """ + with r.pipeline(transaction=False) as pipe: + with pytest.raises(RedisClusterException): 
+
+    def test_multi_key_operation_with_a_single_slot(self, r):
+        """
+        Test multi key operation with a single slot
+        """
+        pipe = r.pipeline(transaction=False)
+        pipe.set('a{foo}', 1)
+        pipe.set('b{foo}', 2)
+        pipe.set('c{foo}', 3)
+        pipe.get('a{foo}')
+        pipe.get('b{foo}')
+        pipe.get('c{foo}')
+
+        res = pipe.execute()
+        assert res == [True, True, True, b'1', b'2', b'3']
+
+    def test_multi_key_operation_with_multi_slots(self, r):
+        """
+        Test multi key operation with more than one slot
+        """
+        pipe = r.pipeline(transaction=False)
+        pipe.set('a{foo}', 1)
+        pipe.set('b{foo}', 2)
+        pipe.set('c{foo}', 3)
+        pipe.set('bar', 4)
+        pipe.set('bazz', 5)
+        pipe.get('a{foo}')
+        pipe.get('b{foo}')
+        pipe.get('c{foo}')
+        pipe.get('bar')
+        pipe.get('bazz')
+        res = pipe.execute()
+        assert res == [True, True, True, True, True, b'1', b'2', b'3', b'4',
+                       b'5']
+
+    def test_connection_error_not_raised(self, r):
+        """
+        Test that the pipeline doesn't raise an error on a connection error
+        when raise_on_error=False
+        """
+        key = 'foo'
+        node = r.get_node_from_key(key, False)
+
+        def raise_connection_error():
+            e = ConnectionError("error")
+            return e
+
+        with r.pipeline() as pipe:
+            mock_node_resp_func(node, raise_connection_error)
+            res = pipe.get(key).get(key).execute(raise_on_error=False)
+            assert node.redis_connection.connection.read_response.called
+            assert isinstance(res[0], ConnectionError)
+
+    def test_connection_error_raised(self, r):
+        """
+        Test that the pipeline raises an error on a connection error when
+        raise_on_error=True
+        """
+        key = 'foo'
+        node = r.get_node_from_key(key, False)
+
+        def raise_connection_error():
+            e = ConnectionError("error")
+            return e
+
+        with r.pipeline() as pipe:
+            mock_node_resp_func(node, raise_connection_error)
+            with pytest.raises(ConnectionError):
+                pipe.get(key).get(key).execute(raise_on_error=True)
+
+    def test_asking_error(self, r):
+        """
+        Test redirection on ASK error
+        """
+        key = 'foo'
+        first_node = r.get_node_from_key(key, False)
+        ask_node = None
+        for node in r.get_nodes():
+            if node != first_node:
+                ask_node = node
+                break
+        if ask_node is None:
+            warnings.warn("skipping this test since the cluster has only one "
+                          "node")
+            return
+        ask_msg = "{0} {1}:{2}".format(r.keyslot(key), ask_node.host,
+                                       ask_node.port)
+
+        def raise_ask_error():
+            raise AskError(ask_msg)
+
+        with r.pipeline() as pipe:
+            mock_node_resp_func(first_node, raise_ask_error)
+            mock_node_resp(ask_node, 'MOCK_OK')
+            res = pipe.get(key).execute()
+            assert first_node.redis_connection.connection.read_response.called
+            assert ask_node.redis_connection.connection.read_response.called
+            assert res == ['MOCK_OK']
+
+    def test_empty_stack(self, r):
+        """
+        If a pipeline is executed with no commands it should
+        return an empty list.
+        """
+        p = r.pipeline()
+        result = p.execute()
+        assert result == []
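+
+
+# Illustrative note: the tests above pin down the ClusterPipeline error
+# contract. With raise_on_error=False, execute() returns exception instances
+# in place of results instead of raising, so callers can inspect per-command
+# failures, e.g. (hypothetical key):
+#
+#   res = pipe.get(key).get(key).execute(raise_on_error=False)
+#   failures = [item for item in res if isinstance(item, Exception)]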
+ """ + p = r.pipeline() + result = p.execute() + assert result == [] + + +@pytest.mark.onlycluster +class TestReadOnlyPipeline: + """ + Tests for ClusterPipeline class in readonly mode + """ + + def test_pipeline_readonly(self, r): + """ + On readonly mode, we supports get related stuff only. + """ + r.readonly(target_nodes='all') + r.set('foo71', 'a1') # we assume this key is set on 127.0.0.1:7001 + r.zadd('foo88', + {'z1': 1}) # we assume this key is set on 127.0.0.1:7002 + r.zadd('foo88', {'z2': 4}) + + with r.pipeline() as readonly_pipe: + readonly_pipe.get('foo71').zrange('foo88', 0, 5, withscores=True) + assert readonly_pipe.execute() == [ + b'a1', + [(b'z1', 1.0), (b'z2', 4)], + ] + + def test_moved_redirection_on_slave_with_default(self, r): + """ + On Pipeline, we redirected once and finally get from master with + readonly client when data is completely moved. + """ + key = 'bar' + r.set(key, 'foo') + # set read_from_replicas to True + r.read_from_replicas = True + primary = r.get_node_from_key(key, False) + replica = r.get_node_from_key(key, True) + with r.pipeline() as readwrite_pipe: + mock_node_resp(primary, "MOCK_FOO") + if replica is not None: + moved_error = "{0} {1}:{2}".format(r.keyslot(key), + primary.host, + primary.port) + + def raise_moved_error(): + raise MovedError(moved_error) + + mock_node_resp_func(replica, raise_moved_error) + assert readwrite_pipe.reinitialize_counter == 0 + readwrite_pipe.get(key).get(key) + assert readwrite_pipe.execute() == ["MOCK_FOO", "MOCK_FOO"] + if replica is not None: + # the slot has a replica as well, so MovedError should have + # occurred. If MovedError occurs, we should see the + # reinitialize_counter increase. + assert readwrite_pipe.reinitialize_counter == 1 + conn = replica.redis_connection.connection + assert conn.read_response.called is True + + def test_readonly_pipeline_from_readonly_client(self, request): + """ + Test that the pipeline is initialized with readonly mode if the client + has it enabled + """ + # Create a cluster with reading from replications + ro = _get_client(RedisCluster, request, read_from_replicas=True) + key = 'bar' + ro.set(key, 'foo') + import time + time.sleep(0.2) + with ro.pipeline() as readonly_pipe: + mock_all_nodes_resp(ro, 'MOCK_OK') + assert readonly_pipe.read_from_replicas is True + assert readonly_pipe.get(key).get( + key).execute() == ['MOCK_OK', 'MOCK_OK'] + slot_nodes = ro.nodes_manager.slots_cache[ro.keyslot(key)] + if len(slot_nodes) > 1: + executed_on_replica = False + for node in slot_nodes: + if node.server_type == REPLICA: + conn = node.redis_connection.connection + executed_on_replica = conn.read_response.called + if executed_on_replica: + break + assert executed_on_replica is True diff --git a/tox.ini b/tox.ini index bfcee9617c..8abd6399a7 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,6 @@ markers = redismod: run only the redis module tests onlycluster: marks tests to be run only with cluster mode redis onlynoncluster: marks tests to be run only with non-cluster redis - pipeline: [tox] minversion = 3.2.0 From 608949c59941dfa933cd04b7665d330b4cc94a16 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Tue, 16 Nov 2021 17:48:55 +0200 Subject: [PATCH 15/22] fixed install_and_test.sh pytest command to include markers and cluster mode tests --- .github/workflows/install_and_test.sh | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh index 330102eb41..7a9701f18d 100755 --- 
a/.github/workflows/install_and_test.sh +++ b/.github/workflows/install_and_test.sh @@ -38,4 +38,8 @@ cd ${TESTDIR} # install, run tests pip install ${PKG} -pytest +# Redis tests +pytest -m 'not onlycluster and not redismod' +# RedisCluster tests +CLUSTER_URL="redis://localhost:16379/0" +pytest -m 'not onlynoncluster and not redismod' --redis-url=${CLUSTER_URL} From 85aff46b34dd8d5253d461290939da9265a2720c Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 17 Nov 2021 11:22:15 +0200 Subject: [PATCH 16/22] Added support for RedisCluster to pass redis URL without a port and setting the port to the default value (6379). e.g. "redis://localhost" will be parsed to host=localhost, port=6379 --- redis/cluster.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/redis/cluster.py b/redis/cluster.py index ff2714dc5d..8a74a23836 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -436,7 +436,9 @@ def __init__( "A ``db`` querystring option can only be 0 in cluster mode" ) kwargs.update(url_options) - startup_nodes.append(ClusterNode(kwargs['host'], kwargs['port'])) + host = kwargs.get('host') + port = kwargs.get('port', port) + startup_nodes.append(ClusterNode(host, port)) elif host is not None and port is not None: startup_nodes.append(ClusterNode(host, port)) elif len(startup_nodes) == 0: From a632159ca43368363fc7945b42020c33b6f9be96 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 17 Nov 2021 13:15:35 +0200 Subject: [PATCH 17/22] Added ignore test files to the codecov configurations --- codecov.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/codecov.yml b/codecov.yml index 449ec0c50f..b6c1e6a90f 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,6 +1,7 @@ ignore: - "benchmarks/**" - "tasks.py" + - "test_*.py" codecov: require_ci_to_pass: yes From 854a065deff21e8a514f40d232466c2aeeae13ee Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Wed, 17 Nov 2021 15:18:46 +0200 Subject: [PATCH 18/22] Fixed codecov ignore tests configuration --- codecov.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/codecov.yml b/codecov.yml index b6c1e6a90f..b1cbc68cb0 100644 --- a/codecov.yml +++ b/codecov.yml @@ -1,7 +1,7 @@ ignore: - "benchmarks/**" - "tasks.py" - - "test_*.py" + - "tests/test_*.py" codecov: require_ci_to_pass: yes From d6d2d29170d23210dda4b4ed99fa1d4b3dc8e3b3 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Thu, 18 Nov 2021 15:46:57 +0200 Subject: [PATCH 19/22] Added coverage unique names to the Redis's and ClusterRedis's codecov reports so it won't get override. --- tox.ini | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 8abd6399a7..c68dab18ea 100644 --- a/tox.ini +++ b/tox.ini @@ -108,8 +108,8 @@ extras = setenv = CLUSTER_URL = "redis://localhost:16379/0" commands = - redis: pytest --cov=./ --cov-report=xml -W always -m 'not onlycluster and not redismod' {posargs} - cluster: pytest --cov=./ --cov-report=xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} + redis: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs} + cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs} [testenv:devenv] skipsdist = true From c0226b0f2479fe1d29b3418aab4f70c3fe7bdbd8 Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Thu, 18 Nov 2021 16:25:52 +0200 Subject: [PATCH 20/22] Rolled back test_sentinel.py changes. 
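
A quick sketch of the behavior this change enables (it assumes a cluster node
is reachable on the default port):

    >>> from redis.cluster import RedisCluster as Redis
    >>> # no port in the URL: the client falls back to the default port 6379
    >>> rc = Redis.from_url("redis://localhost")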

From 85aff46b34dd8d5253d461290939da9265a2720c Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Wed, 17 Nov 2021 11:22:15 +0200
Subject: [PATCH 16/22] Added support for passing a redis URL without a port
 to RedisCluster, setting the port to the default value (6379). e.g.
 "redis://localhost" will be parsed to host=localhost, port=6379
---
 redis/cluster.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/redis/cluster.py b/redis/cluster.py
index ff2714dc5d..8a74a23836 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -436,7 +436,9 @@ def __init__(
                     "A ``db`` querystring option can only be 0 in cluster mode"
                 )
             kwargs.update(url_options)
-            startup_nodes.append(ClusterNode(kwargs['host'], kwargs['port']))
+            host = kwargs.get('host')
+            port = kwargs.get('port', port)
+            startup_nodes.append(ClusterNode(host, port))
         elif host is not None and port is not None:
             startup_nodes.append(ClusterNode(host, port))
         elif len(startup_nodes) == 0:

From a632159ca43368363fc7945b42020c33b6f9be96 Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Wed, 17 Nov 2021 13:15:35 +0200
Subject: [PATCH 17/22] Added the test files to the codecov ignore
 configuration
---
 codecov.yml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/codecov.yml b/codecov.yml
index 449ec0c50f..b6c1e6a90f 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,6 +1,7 @@
 ignore:
  - "benchmarks/**"
  - "tasks.py"
+ - "test_*.py"
 
 codecov:
   require_ci_to_pass: yes

From 854a065deff21e8a514f40d232466c2aeeae13ee Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Wed, 17 Nov 2021 15:18:46 +0200
Subject: [PATCH 18/22] Fixed the codecov ignore-tests configuration
---
 codecov.yml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/codecov.yml b/codecov.yml
index b6c1e6a90f..b1cbc68cb0 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,7 +1,7 @@
 ignore:
  - "benchmarks/**"
  - "tasks.py"
- - "test_*.py"
+ - "tests/test_*.py"
 
 codecov:
   require_ci_to_pass: yes

From d6d2d29170d23210dda4b4ed99fa1d4b3dc8e3b3 Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Thu, 18 Nov 2021 15:46:57 +0200
Subject: [PATCH 19/22] Added unique coverage report names for the Redis and
 RedisCluster codecov runs so they won't get overridden
---
 tox.ini | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tox.ini b/tox.ini
index 8abd6399a7..c68dab18ea 100644
--- a/tox.ini
+++ b/tox.ini
@@ -108,8 +108,8 @@ extras =
 setenv =
     CLUSTER_URL = "redis://localhost:16379/0"
 commands =
-    redis: pytest --cov=./ --cov-report=xml -W always -m 'not onlycluster and not redismod' {posargs}
-    cluster: pytest --cov=./ --cov-report=xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
+    redis: pytest --cov=./ --cov-report=xml:coverage_redis.xml -W always -m 'not onlycluster' {posargs}
+    cluster: pytest --cov=./ --cov-report=xml:coverage_cluster.xml -W always -m 'not onlynoncluster and not redismod' --redis-url={env:CLUSTER_URL:} {posargs}
 
 [testenv:devenv]
 skipsdist = true

From c0226b0f2479fe1d29b3418aab4f70c3fe7bdbd8 Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Thu, 18 Nov 2021 16:25:52 +0200
Subject: [PATCH 20/22] Rolled back the test_sentinel.py changes. Removed the
 tests folder from the codecov ignore section. Removed the 'not redismod'
 marker from the Redis pytest run in install_and_test.sh
---
 .github/workflows/install_and_test.sh |   2 +-
 codecov.yml                           |   1 -
 tests/test_sentinel.py                | 312 ++++++++++++++------------
 3 files changed, 169 insertions(+), 146 deletions(-)

diff --git a/.github/workflows/install_and_test.sh b/.github/workflows/install_and_test.sh
index 7a9701f18d..7a8cd672fd 100755
--- a/.github/workflows/install_and_test.sh
+++ b/.github/workflows/install_and_test.sh
@@ -39,7 +39,7 @@ cd ${TESTDIR}
 # install, run tests
 pip install ${PKG}
 # Redis tests
-pytest -m 'not onlycluster and not redismod'
+pytest -m 'not onlycluster'
 # RedisCluster tests
 CLUSTER_URL="redis://localhost:16379/0"
 pytest -m 'not onlynoncluster and not redismod' --redis-url=${CLUSTER_URL}
diff --git a/codecov.yml b/codecov.yml
index b1cbc68cb0..449ec0c50f 100644
--- a/codecov.yml
+++ b/codecov.yml
@@ -1,7 +1,6 @@
 ignore:
  - "benchmarks/**"
  - "tasks.py"
- - "tests/test_*.py"
 
 codecov:
   require_ci_to_pass: yes
diff --git a/tests/test_sentinel.py b/tests/test_sentinel.py
index 0200cb6cd1..9377d5ba65 100644
--- a/tests/test_sentinel.py
+++ b/tests/test_sentinel.py
@@ -36,24 +36,6 @@ def execute_command(self, *args, **kwargs):
         return bool_ok
 
 
-@pytest.fixture()
-def cluster(request, master_ip):
-    def teardown():
-        redis.sentinel.Redis = saved_Redis
-
-    cluster = SentinelTestCluster(ip=master_ip)
-    saved_Redis = redis.sentinel.Redis
-    redis.sentinel.Redis = cluster.client
-    request.addfinalizer(teardown)
-    return cluster
-
-
-@pytest.fixture()
-def sentinel(request, cluster):
-    return Sentinel([('foo', 26379), ('bar', 26379)])
-
-
-@pytest.mark.onlynoncluster
 class SentinelTestCluster:
     def __init__(self, service_name='mymaster', ip='127.0.0.1', port=6379):
@@ -82,129 +64,171 @@ def timeout_if_down(self, node):
     def client(self, host, port, **kwargs):
         return SentinelTestClient(self, (host, port))
 
-    def test_discover_master(sentinel, master_ip):
-        address = sentinel.discover_master('mymaster')
-        assert address == (master_ip, 6379)
-
-    def test_discover_master_error(sentinel):
-        with pytest.raises(MasterNotFoundError):
-            sentinel.discover_master('xxx')
-
-    def test_discover_master_sentinel_down(cluster, sentinel, master_ip):
-        # Put first sentinel 'foo' down
-        cluster.nodes_down.add(('foo', 26379))
-        address = sentinel.discover_master('mymaster')
-        assert address == (master_ip, 6379)
-        # 'bar' is now first sentinel
-        assert sentinel.sentinels[0].id == ('bar', 26379)
-
-    def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip):
-        # Put first sentinel 'foo' down
-        cluster.nodes_timeout.add(('foo', 26379))
-        address = sentinel.discover_master('mymaster')
-        assert address == (master_ip, 6379)
-        # 'bar' is now first sentinel
-        assert sentinel.sentinels[0].id == ('bar', 26379)
-
-    def test_master_min_other_sentinels(cluster, master_ip):
-        sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
-        # min_other_sentinels
-        with pytest.raises(MasterNotFoundError):
-            sentinel.discover_master('mymaster')
-        cluster.master['num-other-sentinels'] = 2
-        address = sentinel.discover_master('mymaster')
-        assert address == (master_ip, 6379)
-
-    def test_master_odown(cluster, sentinel):
-        cluster.master['is_odown'] = True
-        with pytest.raises(MasterNotFoundError):
-            sentinel.discover_master('mymaster')
-
-    def test_master_sdown(cluster, sentinel):
-        cluster.master['is_sdown'] = True
-        with pytest.raises(MasterNotFoundError):
-            sentinel.discover_master('mymaster')
-
-    def test_discover_slaves(cluster, sentinel):
-        assert sentinel.discover_slaves('mymaster') == []
-
-        cluster.slaves = [
-            {'ip': 'slave0', 'port': 1234, 'is_odown': False,
-             'is_sdown': False},
-            {'ip': 'slave1', 'port': 1234, 'is_odown': False,
-             'is_sdown': False},
-        ]
-        assert sentinel.discover_slaves('mymaster') == [
-            ('slave0', 1234), ('slave1', 1234)]
-
-        # slave0 -> ODOWN
-        cluster.slaves[0]['is_odown'] = True
-        assert sentinel.discover_slaves('mymaster') == [
-            ('slave1', 1234)]
-
-        # slave1 -> SDOWN
-        cluster.slaves[1]['is_sdown'] = True
-        assert sentinel.discover_slaves('mymaster') == []
-
-        cluster.slaves[0]['is_odown'] = False
-        cluster.slaves[1]['is_sdown'] = False
-
-        # node0 -> DOWN
-        cluster.nodes_down.add(('foo', 26379))
-        assert sentinel.discover_slaves('mymaster') == [
-            ('slave0', 1234), ('slave1', 1234)]
-        cluster.nodes_down.clear()
-
-        # node0 -> TIMEOUT
-        cluster.nodes_timeout.add(('foo', 26379))
-        assert sentinel.discover_slaves('mymaster') == [
-            ('slave0', 1234), ('slave1', 1234)]
-
-    def test_master_for(cluster, sentinel, master_ip):
-        master = sentinel.master_for('mymaster', db=9)
-        assert master.ping()
-        assert master.connection_pool.master_address == (master_ip, 6379)
-
-        # Use internal connection check
-        master = sentinel.master_for('mymaster', db=9, check_connection=True)
-        assert master.ping()
-
-    def test_slave_for(cluster, sentinel):
-        cluster.slaves = [
-            {'ip': '127.0.0.1', 'port': 6379,
-             'is_odown': False, 'is_sdown': False},
-        ]
-        slave = sentinel.slave_for('mymaster', db=9)
-        assert slave.ping()
-
-    def test_slave_for_slave_not_found_error(cluster, sentinel):
-        cluster.master['is_odown'] = True
-        slave = sentinel.slave_for('mymaster', db=9)
-        with pytest.raises(SlaveNotFoundError):
-            slave.ping()
-
-    def test_slave_round_robin(cluster, sentinel, master_ip):
-        cluster.slaves = [
-            {'ip': 'slave0', 'port': 6379, 'is_odown': False,
-             'is_sdown': False},
-            {'ip': 'slave1', 'port': 6379, 'is_odown': False,
-             'is_sdown': False},
-        ]
-        pool = SentinelConnectionPool('mymaster', sentinel)
-        rotator = pool.rotate_slaves()
-        assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
-        assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
-        # Fallback to master
-        assert next(rotator) == (master_ip, 6379)
-        with pytest.raises(SlaveNotFoundError):
-            next(rotator)
-
-    def test_ckquorum(cluster, sentinel):
-        assert sentinel.sentinel_ckquorum("mymaster")
-
-    def test_flushconfig(cluster, sentinel):
-        assert sentinel.sentinel_flushconfig()
-
-    def test_reset(cluster, sentinel):
-        cluster.master['is_odown'] = True
-        assert sentinel.sentinel_reset('mymaster')
+
+@pytest.fixture()
+def cluster(request, master_ip):
+    def teardown():
+        redis.sentinel.Redis = saved_Redis
+    cluster = SentinelTestCluster(ip=master_ip)
+    saved_Redis = redis.sentinel.Redis
+    redis.sentinel.Redis = cluster.client
+    request.addfinalizer(teardown)
+    return cluster
+
+
+@pytest.fixture()
+def sentinel(request, cluster):
+    return Sentinel([('foo', 26379), ('bar', 26379)])
+
+
+@pytest.mark.onlynoncluster
+def test_discover_master(sentinel, master_ip):
+    address = sentinel.discover_master('mymaster')
+    assert address == (master_ip, 6379)
+
+
+@pytest.mark.onlynoncluster
+def test_discover_master_error(sentinel):
+    with pytest.raises(MasterNotFoundError):
+        sentinel.discover_master('xxx')
+
+
+@pytest.mark.onlynoncluster
+def test_discover_master_sentinel_down(cluster, sentinel, master_ip):
+    # Put first sentinel 'foo' down
+    cluster.nodes_down.add(('foo', 26379))
+    address = sentinel.discover_master('mymaster')
+    assert address == (master_ip, 6379)
+    # 'bar' is now first sentinel
+    assert sentinel.sentinels[0].id == ('bar', 26379)
+
+
+@pytest.mark.onlynoncluster
+def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip):
+    # Put first sentinel 'foo' down
+    cluster.nodes_timeout.add(('foo', 26379))
+    address = sentinel.discover_master('mymaster')
+    assert address == (master_ip, 6379)
+    # 'bar' is now first sentinel
+    assert sentinel.sentinels[0].id == ('bar', 26379)
+
+
+@pytest.mark.onlynoncluster
+def test_master_min_other_sentinels(cluster, master_ip):
+    sentinel = Sentinel([('foo', 26379)], min_other_sentinels=1)
+    # min_other_sentinels
+    with pytest.raises(MasterNotFoundError):
+        sentinel.discover_master('mymaster')
+    cluster.master['num-other-sentinels'] = 2
+    address = sentinel.discover_master('mymaster')
+    assert address == (master_ip, 6379)
+
+
+@pytest.mark.onlynoncluster
+def test_master_odown(cluster, sentinel):
+    cluster.master['is_odown'] = True
+    with pytest.raises(MasterNotFoundError):
+        sentinel.discover_master('mymaster')
+
+
+@pytest.mark.onlynoncluster
+def test_master_sdown(cluster, sentinel):
+    cluster.master['is_sdown'] = True
+    with pytest.raises(MasterNotFoundError):
+        sentinel.discover_master('mymaster')
+
+
+@pytest.mark.onlynoncluster
+def test_discover_slaves(cluster, sentinel):
+    assert sentinel.discover_slaves('mymaster') == []
+
+    cluster.slaves = [
+        {'ip': 'slave0', 'port': 1234, 'is_odown': False, 'is_sdown': False},
+        {'ip': 'slave1', 'port': 1234, 'is_odown': False, 'is_sdown': False},
+    ]
+    assert sentinel.discover_slaves('mymaster') == [
+        ('slave0', 1234), ('slave1', 1234)]
+
+    # slave0 -> ODOWN
+    cluster.slaves[0]['is_odown'] = True
+    assert sentinel.discover_slaves('mymaster') == [
+        ('slave1', 1234)]
+
+    # slave1 -> SDOWN
+    cluster.slaves[1]['is_sdown'] = True
+    assert sentinel.discover_slaves('mymaster') == []
+
+    cluster.slaves[0]['is_odown'] = False
+    cluster.slaves[1]['is_sdown'] = False
+
+    # node0 -> DOWN
+    cluster.nodes_down.add(('foo', 26379))
+    assert sentinel.discover_slaves('mymaster') == [
+        ('slave0', 1234), ('slave1', 1234)]
+    cluster.nodes_down.clear()
+
+    # node0 -> TIMEOUT
+    cluster.nodes_timeout.add(('foo', 26379))
+    assert sentinel.discover_slaves('mymaster') == [
+        ('slave0', 1234), ('slave1', 1234)]
+
+
+@pytest.mark.onlynoncluster
+def test_master_for(cluster, sentinel, master_ip):
+    master = sentinel.master_for('mymaster', db=9)
+    assert master.ping()
+    assert master.connection_pool.master_address == (master_ip, 6379)
+
+    # Use internal connection check
+    master = sentinel.master_for('mymaster', db=9, check_connection=True)
+    assert master.ping()
+
+
+@pytest.mark.onlynoncluster
+def test_slave_for(cluster, sentinel):
+    cluster.slaves = [
+        {'ip': '127.0.0.1', 'port': 6379,
+         'is_odown': False, 'is_sdown': False},
+    ]
+    slave = sentinel.slave_for('mymaster', db=9)
+    assert slave.ping()
+
+
+@pytest.mark.onlynoncluster
+def test_slave_for_slave_not_found_error(cluster, sentinel):
+    cluster.master['is_odown'] = True
+    slave = sentinel.slave_for('mymaster', db=9)
+    with pytest.raises(SlaveNotFoundError):
+        slave.ping()
+
+
+@pytest.mark.onlynoncluster
+def test_slave_round_robin(cluster, sentinel, master_ip):
+    cluster.slaves = [
+        {'ip': 'slave0', 'port': 6379, 'is_odown': False, 'is_sdown': False},
+        {'ip': 'slave1', 'port': 6379, 'is_odown': False, 'is_sdown': False},
+    ]
+    pool = SentinelConnectionPool('mymaster', sentinel)
+    rotator = pool.rotate_slaves()
+    assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
+    assert next(rotator) in (('slave0', 6379), ('slave1', 6379))
+    # Fallback to master
+    assert next(rotator) == (master_ip, 6379)
+    with pytest.raises(SlaveNotFoundError):
+        next(rotator)
+
+
+@pytest.mark.onlynoncluster
+def test_ckquorum(cluster, sentinel):
+    assert sentinel.sentinel_ckquorum("mymaster")
+
+
+@pytest.mark.onlynoncluster
+def test_flushconfig(cluster, sentinel):
+    assert sentinel.sentinel_flushconfig()
+
+
+@pytest.mark.onlynoncluster
+def test_reset(cluster, sentinel):
+    cluster.master['is_odown'] = True
+    assert sentinel.sentinel_reset('mymaster')

From a2b022bca62de5360a7a338975ef85bce1153690 Mon Sep 17 00:00:00 2001
From: Bar Shaul
Date: Sun, 21 Nov 2021 16:05:45 +0200
Subject: [PATCH 21/22] Changed the PubSub's health check command to be
 performed only on the first command execution.
---
 redis/client.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/redis/client.py b/redis/client.py
index 4413b94996..ba12d86271 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -1241,6 +1241,7 @@ def reset(self):
         self.pending_unsubscribe_channels = set()
         self.patterns = {}
         self.pending_unsubscribe_patterns = set()
+        self.cmd_execution_health_check = True
 
     def close(self):
         self.reset()
@@ -1284,8 +1285,11 @@ def execute_command(self, *args):
             # were listening to when we were disconnected
             self.connection.register_connect_callback(self.on_connect)
         connection = self.connection
-        kwargs = {'check_health': not self.subscribed}
+        kwargs = {'check_health': self.cmd_execution_health_check}
         self._execute(connection, connection.send_command, *args, **kwargs)
+        if self.cmd_execution_health_check is True:
+            # Run a health check only on the first command execution
+            self.cmd_execution_health_check = False
 
     def _disconnect_raise_connect(self, conn, error):
         """
@@ -1437,6 +1441,10 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
         before returning. Timeout should be specified as a floating point
         number.
         """
+        if self.cmd_execution_health_check is True:
+            # Health checks will be done within the parse_response method,
+            # cancel health checks from the command_execution method
+            self.cmd_execution_health_check = False
         response = self.parse_response(block=False, timeout=timeout)
         if response:
             return self.handle_message(response, ignore_subscribe_messages)
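
A sketch of the resulting behavior (hypothetical channel name; assumes a
reachable server): the health check now precedes only the first PubSub
command, rather than every command sent while unsubscribed:

    >>> p = r.pubsub()
    >>> p.subscribe('ch')    # first command: a health check runs before it
    >>> p.get_message()      # later calls skip the extra health check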
""" + if self.cmd_execution_health_check is True: + # Health checks will be done within the parse_response method, + # cancel health checks from the command_execution method + self.cmd_execution_health_check = False response = self.parse_response(block=False, timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) From 39f5cc0332bbcd7caa70c8b8b1cf8a15b6ebd9fd Mon Sep 17 00:00:00 2001 From: Bar Shaul Date: Mon, 22 Nov 2021 10:46:49 +0200 Subject: [PATCH 22/22] Changed get_message logic to wait for subscription --- redis/client.py | 31 +++++++++++++++++++++---------- tests/test_pubsub.py | 43 +++++++++++++++++++++++++++++++++++-------- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/redis/client.py b/redis/client.py index ba12d86271..fe16b99491 100755 --- a/redis/client.py +++ b/redis/client.py @@ -1241,7 +1241,6 @@ def reset(self): self.pending_unsubscribe_channels = set() self.patterns = {} self.pending_unsubscribe_patterns = set() - self.cmd_execution_health_check = True def close(self): self.reset() @@ -1285,11 +1284,8 @@ def execute_command(self, *args): # were listening to when we were disconnected self.connection.register_connect_callback(self.on_connect) connection = self.connection - kwargs = {'check_health': self.cmd_execution_health_check} + kwargs = {'check_health': not self.subscribed} self._execute(connection, connection.send_command, *args, **kwargs) - if self.cmd_execution_health_check is True: - # Run a health check only on the first command execution - self.cmd_execution_health_check = False def _disconnect_raise_connect(self, conn, error): """ @@ -1440,16 +1436,31 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0): If timeout is specified, the system will wait for `timeout` seconds before returning. Timeout should be specified as a floating point number. + + if not self.subscribed and \ + self.wait_for_subscription(timeout) is False: + # The connection isn't subscribed to any channels or patterns, so + # no messages are available + return None """ - if self.cmd_execution_health_check is True: - # Health checks will be done within the parse_response method, - # cancel health checks from the command_execution method - self.cmd_execution_health_check = False response = self.parse_response(block=False, timeout=timeout) if response: return self.handle_message(response, ignore_subscribe_messages) return None + def wait_for_subscription(self, timeout, period=0.25): + """ + Wait until this pubsub connection has been subscribed. + Return True if the connection was subscribed during the timeout + frametime. Otherwise, return False. 
+ """ + mustend = time.time() + timeout + while time.time() < mustend: + if self.subscribed: + return True + time.sleep(period) + return False + def ping(self, message=None): """ Ping the Redis server @@ -1721,7 +1732,7 @@ def pipeline_execute_command(self, *args, **options): return self def _execute_transaction(self, connection, commands, raise_on_error): - cmds = chain([(('MULTI', ), {})], commands, [(('EXEC', ), {})]) + cmds = chain([(('MULTI',), {})], commands, [(('EXEC',), {})]) all_cmds = connection.pack_commands([args for args, options in cmds if EMPTY_RESPONSE not in options]) connection.send_packed_command(all_cmds) diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index 95513a09a8..6b7581e3f1 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -1,10 +1,12 @@ import threading import time from unittest import mock +from unittest.mock import patch import platform import pytest import redis +from redis.client import PubSub from redis.exceptions import ConnectionError from .conftest import ( @@ -344,14 +346,6 @@ def test_unicode_pattern_message_handler(self, r): assert self.message == make_message('pmessage', channel, 'test message', pattern=pattern) - def test_get_message_without_subscribe(self, r): - p = r.pubsub() - with pytest.raises(RuntimeError) as info: - p.get_message() - expect = ('connection not set: ' - 'did you forget to call subscribe() or psubscribe()?') - assert expect in info.exconly() - class TestPubSubAutoDecoding: "These tests only validate that we get unicode values back" @@ -562,6 +556,39 @@ def test_get_message_with_timeout_returns_none(self, r): assert wait_for_message(p) == make_message('subscribe', 'foo', 1) assert p.get_message(timeout=0.01) is None + def test_get_message_not_subscribed_return_none(self, r): + p = r.pubsub() + assert p.subscribed is False + assert p.get_message() is None + assert p.get_message(timeout=0.1) is None + with patch.object(PubSub, 'wait_for_subscription') as mock: + mock.return_value = False + assert p.get_message(timeout=0.01) is None + assert mock.called + + def test_get_message_subscribe_during_waiting(self, r): + p = r.pubsub() + + def poll(ps, expected_res): + assert ps.get_message() is None + message = ps.get_message(timeout=1) + assert message == expected_res + + subscribe_response = make_message('subscribe', 'foo', 1) + poller = threading.Thread(target=poll, args=(p, subscribe_response)) + poller.start() + time.sleep(0.2) + p.subscribe('foo') + poller.join() + + def test_get_message_wait_for_subscription_not_being_called(self, r): + p = r.pubsub() + p.subscribe('foo') + with patch.object(PubSub, 'wait_for_subscription') as mock: + assert p.subscribed is True + assert wait_for_message(p) == make_message('subscribe', 'foo', 1) + assert mock.called is False + class TestPubSubWorkerThread:
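
With this change, polling before subscribing no longer raises a RuntimeError
(a sketch; hypothetical channel name, assumes a reachable server):

    >>> p = r.pubsub()
    >>> p.get_message(timeout=0.1)   # not subscribed yet: returns None
    >>> p.subscribe('foo')
    >>> p.get_message(timeout=1)     # the subscribe confirmation message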