diff --git a/redis/client.py b/redis/client.py index 59575cd835..ea36ef0ddc 100755 --- a/redis/client.py +++ b/redis/client.py @@ -530,7 +530,7 @@ def parse_client_info(value): "key1=value1 key2=value2 key3=value3" """ client_info = {} - infos = value.split(" ") + infos = str_if_bytes(value).split(" ") for info in infos: key, value = info.split("=") client_info[key] = value @@ -538,7 +538,7 @@ def parse_client_info(value): # Those fields are definded as int in networking.c for int_key in {"id", "age", "idle", "db", "sub", "psub", "multi", "qbuf", "qbuf-free", "obl", - "oll", "omem"}: + "argv-mem", "oll", "omem", "tot-mem"}: client_info[int_key] = int(client_info[int_key]) return client_info @@ -561,7 +561,7 @@ class Redis: """ RESPONSE_CALLBACKS = { **string_keys_to_dict( - 'AUTH EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' + 'AUTH COPY EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST ' 'PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX', bool ), @@ -620,6 +620,7 @@ class Redis: 'CLIENT ID': int, 'CLIENT KILL': parse_client_kill, 'CLIENT LIST': parse_client_list, + 'CLIENT INFO': parse_client_info, 'CLIENT SETNAME': bool_ok, 'CLIENT UNBLOCK': lambda r: r and int(r) == 1 or False, 'CLIENT PAUSE': bool_ok, @@ -1209,7 +1210,8 @@ def client_kill(self, address): "Disconnects the client at ``address`` (ip:port)" return self.execute_command('CLIENT KILL', address) - def client_kill_filter(self, _id=None, _type=None, addr=None, skipme=None): + def client_kill_filter(self, _id=None, _type=None, addr=None, + skipme=None, laddr=None): """ Disconnects client(s) using a variety of filter options :param id: Kills a client by its unique ID field @@ -1217,6 +1219,7 @@ def client_kill_filter(self, _id=None, _type=None, addr=None, skipme=None): 'master', 'slave' or 'pubsub' :param addr: Kills a client by its 'address:port' :param skipme: If True, then the client calling the command + :param laddr: Kills a cient by its 'local (bind) address:port' will not get killed even if it is identified by one of the filter options. If skipme is not provided, the server defaults to skipme=True """ @@ -1238,11 +1241,20 @@ def client_kill_filter(self, _id=None, _type=None, addr=None, skipme=None): args.extend((b'ID', _id)) if addr is not None: args.extend((b'ADDR', addr)) + if laddr is not None: + args.extend((b'LADDR', laddr)) if not args: raise DataError("CLIENT KILL ... ... " " must specify at least one filter") return self.execute_command('CLIENT KILL', *args) + def client_info(self): + """ + Returns information and statistics about the current + client connection. + """ + return self.execute_command('CLIENT INFO') + def client_list(self, _type=None): """ Returns a list of currently connected clients. @@ -1292,6 +1304,12 @@ def client_pause(self, timeout): raise DataError("CLIENT PAUSE timeout must be an integer") return self.execute_command('CLIENT PAUSE', str(timeout)) + def client_unpause(self): + """ + Unpause all redis clients + """ + return self.execute_command('CLIENT UNPAUSE') + def readwrite(self): "Disables read queries for a connection to a Redis Cluster slave node" return self.execute_command('READWRITE') @@ -1612,6 +1630,24 @@ def bitpos(self, key, bit, start=None, end=None): "when end is specified") return self.execute_command('BITPOS', *params) + def copy(self, source, destination, destination_db=None, replace=False): + """ + Copy the value stored in the ``source`` key to the ``destination`` key. + + ``destination_db`` an alternative destination database. By default, + the ``destination`` key is created in the source Redis database. + + ``replace`` whether the ``destination`` key should be removed before + copying the value to it. By default, the value is not copied if + the ``destination`` key already exists. + """ + params = [source, destination] + if destination_db is not None: + params.extend(["DB", destination_db]) + if replace: + params.append("REPLACE") + return self.execute_command('COPY', *params) + def decr(self, name, amount=1): """ Decrements the value of ``key`` by ``amount``. If no key exists, @@ -1671,6 +1707,66 @@ def get(self, name): """ return self.execute_command('GET', name) + def getdel(self, name): + """ + Get the value at key ``name`` and delete the key. This command + is similar to GET, except for the fact that it also deletes + the key on success (if and only if the key's value type + is a string). + """ + return self.execute_command('GETDEL', name) + + def getex(self, name, + ex=None, px=None, exat=None, pxat=None, persist=False): + """ + Get the value of key and optionally set its expiration. + GETEX is similar to GET, but is a write command with + additional options. All time parameters can be given as + datetime.timedelta or integers. + + ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds. + + ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds. + + ``exat`` sets an expire flag on key ``name`` for ``ex`` seconds, + specified in unix time. + + ``pxat`` sets an expire flag on key ``name`` for ``ex`` milliseconds, + specified in unix time. + + ``persist`` remove the time to live associated with ``name``. + """ + + pieces = [] + # similar to set command + if ex is not None: + pieces.append('EX') + if isinstance(ex, datetime.timedelta): + ex = int(ex.total_seconds()) + pieces.append(ex) + if px is not None: + pieces.append('PX') + if isinstance(px, datetime.timedelta): + px = int(px.total_seconds() * 1000) + pieces.append(px) + # similar to pexpireat command + if exat is not None: + pieces.append('EXAT') + if isinstance(exat, datetime.datetime): + s = int(exat.microsecond / 1000000) + exat = int(mod_time.mktime(exat.timetuple())) + s + pieces.append(exat) + if pxat is not None: + pieces.append('PXAT') + if isinstance(pxat, datetime.datetime): + ms = int(pxat.microsecond / 1000) + pxat = int(mod_time.mktime(pxat.timetuple())) * 1000 + ms + pieces.append(pxat) + if persist: + pieces.append('PERSIST') + + return self.execute_command('GETEX', name, *pieces) + def __getitem__(self, name): """ Return the value at key ``name``, raises a KeyError if the key @@ -1802,6 +1898,26 @@ def pttl(self, name): "Returns the number of milliseconds until the key ``name`` will expire" return self.execute_command('PTTL', name) + def hrandfield(self, key, count=None, withvalues=False): + """ + Return a random field from the hash value stored at key. + + count: if the argument is positive, return an array of distinct fields. + If called with a negative count, the behavior changes and the command + is allowed to return the same field multiple times. In this case, + the number of returned fields is the absolute value of the + specified count. + withvalues: The optional WITHVALUES modifier changes the reply so it + includes the respective values of the randomly selected hash fields. + """ + params = [] + if count is not None: + params.append(count) + if withvalues: + params.append("WITHVALUES") + + return self.execute_command("HRANDFIELD", key, *params) + def randomkey(self): "Returns the name of a random key" return self.execute_command('RANDOMKEY') @@ -2434,7 +2550,8 @@ def xack(self, name, groupname, *ids): """ return self.execute_command('XACK', name, groupname, *ids) - def xadd(self, name, fields, id='*', maxlen=None, approximate=True): + def xadd(self, name, fields, id='*', maxlen=None, approximate=True, + nomkstream=False): """ Add to a stream. name: name of the stream @@ -2442,7 +2559,7 @@ def xadd(self, name, fields, id='*', maxlen=None, approximate=True): id: Location to insert this record. By default it is appended. maxlen: truncate old stream members beyond this size approximate: actual stream length may be slightly more than maxlen - + nomkstream: When set to true, do not make a stream """ pieces = [] if maxlen is not None: @@ -2452,6 +2569,8 @@ def xadd(self, name, fields, id='*', maxlen=None, approximate=True): if approximate: pieces.append(b'~') pieces.append(str(maxlen)) + if nomkstream: + pieces.append(b'NOMKSTREAM') pieces.append(id) if not isinstance(fields, dict) or len(fields) == 0: raise DataError('XADD fields must be a non-empty dict') @@ -2747,7 +2866,8 @@ def xtrim(self, name, maxlen, approximate=True): return self.execute_command('XTRIM', name, *pieces) # SORTED SET COMMANDS - def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): + def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False, + gt=None, lt=None): """ Set any number of element-name, score pairs to the key ``name``. Pairs are specified as a dict of element-names keys to score values. @@ -2778,6 +2898,9 @@ def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): if incr and len(mapping) != 1: raise DataError("ZADD option 'incr' only works when passing a " "single element/score pair") + if nx is True and (gt is not None or lt is not None): + raise DataError("Only one of 'nx', 'lt', or 'gr' may be defined.") + pieces = [] options = {} if nx: @@ -2789,6 +2912,10 @@ def zadd(self, name, mapping, nx=False, xx=False, ch=False, incr=False): if incr: pieces.append(b'INCR') options['as_score'] = True + if gt: + pieces.append(b'GT') + if lt: + pieces.append(b'LT') for pair in mapping.items(): pieces.append(pair[1]) pieces.append(pair[0]) @@ -2846,6 +2973,28 @@ def zpopmin(self, name, count=None): } return self.execute_command('ZPOPMIN', name, *args, **options) + def zrandmember(self, key, count=None, withscores=False): + """ + Return a random element from the sorted set value stored at key. + + ``count`` if the argument is positive, return an array of distinct + fields. If called with a negative count, the behavior changes and + the command is allowed to return the same field multiple times. + In this case, the number of returned fields is the absolute value + of the specified count. + + ``withscores`` The optional WITHSCORES modifier changes the reply so it + includes the respective scores of the randomly selected elements from + the sorted set. + """ + params = [] + if count is not None: + params.append(count) + if withscores: + params.append("WITHSCORES") + + return self.execute_command("ZRANDMEMBER", key, *params) + def bzpopmax(self, keys, timeout=0): """ ZPOPMAX a value off of the first non-empty sorted set diff --git a/tests/test_commands.py b/tests/test_commands.py index 2da4a89e63..62394a4412 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -281,6 +281,12 @@ def test_client_list(self, r): assert isinstance(clients[0], dict) assert 'addr' in clients[0] + @skip_if_server_version_lt('6.2.0') + def test_client_info(self, r): + info = r.client_info() + assert isinstance(info, dict) + assert 'addr' in info + @skip_if_server_version_lt('5.0.0') def test_client_list_type(self, r): with pytest.raises(exceptions.RedisError): @@ -389,6 +395,26 @@ def test_client_list_after_client_setname(self, r): # we don't know which client ours will be assert 'redis_py_test' in [c['name'] for c in clients] + @skip_if_server_version_lt('6.2.0') + def test_client_kill_filter_by_laddr(self, r, r2): + r.client_setname('redis-py-c1') + r2.client_setname('redis-py-c2') + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 2 + + clients_by_name = dict([(client.get('name'), client) + for client in clients]) + + client_2_addr = clients_by_name['redis-py-c2'].get('laddr') + resp = r.client_kill_filter(laddr=client_2_addr) + assert resp == 1 + + clients = [client for client in r.client_list() + if client.get('name') in ['redis-py-c1', 'redis-py-c2']] + assert len(clients) == 1 + assert clients[0].get('name') == 'redis-py-c1' + @skip_if_server_version_lt('2.9.50') def test_client_pause(self, r): assert r.client_pause(1) @@ -396,6 +422,10 @@ def test_client_pause(self, r): with pytest.raises(exceptions.RedisError): r.client_pause(timeout='not an integer') + @skip_if_server_version_lt('6.2.0') + def test_client_unpause(self, r): + assert r.client_unpause() == b'OK' + def test_config_get(self, r): data = r.config_get() assert 'maxmemory' in data @@ -578,6 +608,29 @@ def test_bitpos_wrong_arguments(self, r): with pytest.raises(exceptions.RedisError): r.bitpos(key, 7) == 12 + @skip_if_server_version_lt('6.2.0') + def test_copy(self, r): + assert r.copy("a", "b") == 0 + r.set("a", "foo") + assert r.copy("a", "b") == 1 + assert r.get("a") == b"foo" + assert r.get("b") == b"foo" + + @skip_if_server_version_lt('6.2.0') + def test_copy_and_replace(self, r): + r.set("a", "foo1") + r.set("b", "foo2") + assert r.copy("a", "b") == 0 + assert r.copy("a", "b", replace=True) == 1 + + @skip_if_server_version_lt('6.2.0') + def test_copy_to_another_database(self, request): + r0 = _get_client(redis.Redis, request, db=0) + r1 = _get_client(redis.Redis, request, db=1) + r0.set("a", "foo") + assert r0.copy("a", "b", destination_db=1) == 1 + assert r1.get("b") == b"foo" + def test_decr(self, r): assert r.decr('a') == -1 assert r['a'] == b'-1' @@ -704,6 +757,28 @@ def test_get_and_set(self, r): assert r.get('integer') == str(integer).encode() assert r.get('unicode_string').decode('utf-8') == unicode_string + @skip_if_server_version_lt('6.2.0') + def test_getdel(self, r): + assert r.getdel('a') is None + r.set('a', 1) + assert r.getdel('a') == b'1' + assert r.getdel('a') is None + + @skip_if_server_version_lt('6.2.0') + def test_getex(self, r): + r.set('a', 1) + assert r.getex('a') == b'1' + assert r.ttl('a') == -1 + assert r.getex('a', ex=60) == b'1' + assert r.ttl('a') == 60 + assert r.getex('a', px=6000) == b'1' + assert r.ttl('a') == 6 + expire_at = redis_server_time(r) + datetime.timedelta(minutes=1) + assert r.getex('a', pxat=expire_at) == b'1' + assert r.ttl('a') <= 60 + assert r.getex('a', persist=True) == b'1' + assert r.ttl('a') == -1 + def test_getitem_and_setitem(self, r): r['a'] = 'bar' assert r['a'] == b'bar' @@ -851,6 +926,19 @@ def test_pttl_no_key(self, r): "PTTL on servers 2.8 and after return -2 when the key doesn't exist" assert r.pttl('a') == -2 + @skip_if_server_version_lt('6.2.0') + def test_hrandfield(self, r): + assert r.hrandfield('key') is None + r.hset('key', mapping={'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5}) + assert r.hrandfield('key') is not None + assert len(r.hrandfield('key', 2)) == 2 + # with values + assert len(r.hrandfield('key', 2, True)) == 4 + # without duplications + assert len(r.hrandfield('key', 10)) == 5 + # with duplications + assert len(r.hrandfield('key', -10)) == 10 + def test_randomkey(self, r): assert r.randomkey() is None for key in ('a', 'b', 'c'): @@ -1374,6 +1462,23 @@ def test_zadd_incr_with_xx(self, r): # redis-py assert r.zadd('a', {'a1': 1}, xx=True, incr=True) is None + @skip_if_server_version_lt('6.2.0') + def test_zadd_gt_lt(self, r): + + for i in range(1, 20): + r.zadd('a', {'a%s' % i: i}) + assert r.zadd('a', {'a20': 5}, gt=3) == 1 + + for i in range(1, 20): + r.zadd('a', {'a%s' % i: i}) + assert r.zadd('a', {'a2': 5}, lt=1) == 0 + + # cannot use both nx and xx options + with pytest.raises(exceptions.DataError): + r.zadd('a', {'a15': 155}, nx=True, lt=True) + r.zadd('a', {'a15': 155}, nx=True, gt=True) + r.zadd('a', {'a15': 155}, lx=True, gt=True) + def test_zcard(self, r): r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3}) assert r.zcard('a') == 3 @@ -1449,6 +1554,18 @@ def test_zpopmin(self, r): assert r.zpopmin('a', count=2) == \ [(b'a2', 2), (b'a3', 3)] + @skip_if_server_version_lt('6.2.0') + def test_zrandemember(self, r): + r.zadd('a', {'a1': 1, 'a2': 2, 'a3': 3, 'a4': 4, 'a5': 5}) + assert r.zrandmember('a') is not None + assert len(r.zrandmember('a', 2)) == 2 + # with scores + assert len(r.zrandmember('a', 2, True)) == 4 + # without duplications + assert len(r.zrandmember('a', 10)) == 5 + # with duplications + assert len(r.zrandmember('a', -10)) == 10 + @skip_if_server_version_lt('4.9.0') def test_bzpopmax(self, r): r.zadd('a', {'a1': 1, 'a2': 2}) @@ -2190,6 +2307,16 @@ def test_xadd(self, r): r.xadd(stream, {'foo': 'bar'}, maxlen=2, approximate=False) assert r.xlen(stream) == 2 + @skip_if_server_version_lt('6.2.0') + def test_xadd_nomkstream(self, r): + # nomkstream option + stream = 'stream' + r.xadd(stream, {'foo': 'bar'}) + r.xadd(stream, {'some': 'other'}, nomkstream=False) + assert r.xlen(stream) == 2 + r.xadd(stream, {'some': 'other'}, nomkstream=True) + assert r.xlen(stream) == 3 + @skip_if_server_version_lt('5.0.0') def test_xclaim(self, r): stream = 'stream'