diff --git a/CHANGES b/CHANGES index 9d82341a76..09babb706f 100644 --- a/CHANGES +++ b/CHANGES @@ -1,3 +1,4 @@ + * Use hiredis-py pack_command if available. * Support `.unlink()` in ClusterPipeline * Simplify synchronous SocketBuffer state management * Fix string cleanse in Redis Graph diff --git a/redis/connection.py b/redis/connection.py old mode 100755 new mode 100644 index 57f0a3a81e..114221d8e9 --- a/redis/connection.py +++ b/redis/connection.py @@ -3,6 +3,7 @@ import io import os import socket +import sys import threading import weakref from io import SEEK_END @@ -32,7 +33,12 @@ TimeoutError, ) from redis.retry import Retry -from redis.utils import CRYPTOGRAPHY_AVAILABLE, HIREDIS_AVAILABLE, str_if_bytes +from redis.utils import ( + CRYPTOGRAPHY_AVAILABLE, + HIREDIS_AVAILABLE, + HIREDIS_PACK_AVAILABLE, + str_if_bytes, +) try: import ssl @@ -509,6 +515,75 @@ def read_response(self, disable_decoding=False): DefaultParser = PythonParser +class HiredisRespSerializer: + def pack(self, *args): + """Pack a series of arguments into the Redis protocol""" + output = [] + + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + try: + output.append(hiredis.pack_command(args)) + except TypeError: + _, value, traceback = sys.exc_info() + raise DataError(value).with_traceback(traceback) + + return output + + +class PythonRespSerializer: + def __init__(self, buffer_cutoff, encode) -> None: + self._buffer_cutoff = buffer_cutoff + self.encode = encode + + def pack(self, *args): + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. + if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + class Connection: "Manages TCP communication to and from a Redis server" @@ -536,6 +611,7 @@ def __init__( retry=None, redis_connect_func=None, credential_provider: Optional[CredentialProvider] = None, + command_packer=None, ): """ Initialize a new Connection. @@ -590,6 +666,7 @@ def __init__( self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 + self._command_packer = self._construct_command_packer(command_packer) def __repr__(self): repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()]) @@ -607,6 +684,14 @@ def __del__(self): except Exception: pass + def _construct_command_packer(self, packer): + if packer is not None: + return packer + elif HIREDIS_PACK_AVAILABLE: + return HiredisRespSerializer() + else: + return PythonRespSerializer(self._buffer_cutoff, self.encoder.encode) + def register_connect_callback(self, callback): self._connect_callbacks.append(weakref.WeakMethod(callback)) @@ -827,7 +912,8 @@ def send_packed_command(self, command, check_health=True): def send_command(self, *args, **kwargs): """Pack and send a command to the Redis server""" self.send_packed_command( - self.pack_command(*args), check_health=kwargs.get("check_health", True) + self._command_packer.pack(*args), + check_health=kwargs.get("check_health", True), ) def can_read(self, timeout=0): @@ -872,48 +958,7 @@ def read_response(self, disable_decoding=False): def pack_command(self, *args): """Pack a series of arguments into the Redis protocol""" - output = [] - # the client might have included 1 or more literal arguments in - # the command name, e.g., 'CONFIG GET'. The Redis server expects these - # arguments to be sent separately, so split the first argument - # manually. These arguments should be bytestrings so that they are - # not encoded. - if isinstance(args[0], str): - args = tuple(args[0].encode().split()) + args[1:] - elif b" " in args[0]: - args = tuple(args[0].split()) + args[1:] - - buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) - - buffer_cutoff = self._buffer_cutoff - for arg in map(self.encoder.encode, args): - # to avoid large string mallocs, chunk the command into the - # output list if we're sending large values or memoryviews - arg_length = len(arg) - if ( - len(buff) > buffer_cutoff - or arg_length > buffer_cutoff - or isinstance(arg, memoryview) - ): - buff = SYM_EMPTY.join( - (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) - ) - output.append(buff) - output.append(arg) - buff = SYM_CRLF - else: - buff = SYM_EMPTY.join( - ( - buff, - SYM_DOLLAR, - str(arg_length).encode(), - SYM_CRLF, - arg, - SYM_CRLF, - ) - ) - output.append(buff) - return output + return self._command_packer.pack(*args) def pack_commands(self, commands): """Pack multiple commands into the Redis protocol""" @@ -923,7 +968,7 @@ def pack_commands(self, commands): buffer_cutoff = self._buffer_cutoff for cmd in commands: - for chunk in self.pack_command(*cmd): + for chunk in self._command_packer.pack(*cmd): chunklen = len(chunk) if ( buffer_length > buffer_cutoff diff --git a/redis/utils.py b/redis/utils.py index 693d4e64b5..d95e62c042 100644 --- a/redis/utils.py +++ b/redis/utils.py @@ -7,8 +7,10 @@ # Only support Hiredis >= 1.0: HIREDIS_AVAILABLE = not hiredis.__version__.startswith("0.") + HIREDIS_PACK_AVAILABLE = hasattr(hiredis, "pack_command") except ImportError: HIREDIS_AVAILABLE = False + HIREDIS_PACK_AVAILABLE = False try: import cryptography # noqa diff --git a/tests/test_encoding.py b/tests/test_encoding.py index 2867640742..cb9c4e20be 100644 --- a/tests/test_encoding.py +++ b/tests/test_encoding.py @@ -2,6 +2,7 @@ import redis from redis.connection import Connection +from redis.utils import HIREDIS_PACK_AVAILABLE from .conftest import _get_client @@ -75,6 +76,10 @@ def test_replace(self, request): assert r.get("a") == "foo\ufffd" +@pytest.mark.skipif( + HIREDIS_PACK_AVAILABLE, + reason="Packing via hiredis does not preserve memoryviews", +) class TestMemoryviewsAreNotPacked: def test_memoryviews_are_not_packed(self): c = Connection()