Skip to content

Commit 157be78

Browse files
committed
Make syncronouse PythonParser restartable on error, same as HiredisParser
1 parent c15241e commit 157be78

File tree

2 files changed

+100
-14
lines changed

2 files changed

+100
-14
lines changed

redis/connection.py

+19-13
Original file line numberDiff line numberDiff line change
@@ -232,12 +232,6 @@ def read(self, length):
232232
self._buffer.seek(self.bytes_read)
233233
data = self._buffer.read(length)
234234
self.bytes_read += len(data)
235-
236-
# purge the buffer when we've consumed it all so it doesn't
237-
# grow forever
238-
if self.bytes_read == self.bytes_written:
239-
self.purge()
240-
241235
return data[:-2]
242236

243237
def readline(self):
@@ -251,15 +245,18 @@ def readline(self):
251245
data = buf.readline()
252246

253247
self.bytes_read += len(data)
254-
255-
# purge the buffer when we've consumed it all so it doesn't
256-
# grow forever
257-
if self.bytes_read == self.bytes_written:
258-
self.purge()
259-
260248
return data[:-2]
261249

250+
def rewind(self):
251+
"""
252+
Rewind the buffer to the beginning, to re-start reading.
253+
"""
254+
self.bytes_read = 0
255+
262256
def purge(self):
257+
"""
258+
After a successful read, purge the buffer
259+
"""
263260
self._buffer.seek(0)
264261
self._buffer.truncate()
265262
self.bytes_written = 0
@@ -315,6 +312,15 @@ def can_read(self, timeout):
315312
return self._buffer and self._buffer.can_read(timeout)
316313

317314
def read_response(self, disable_decoding=False):
315+
try:
316+
result = self._read_response(disable_decoding=disable_decoding)
317+
self._buffer.purge()
318+
return result
319+
except BaseException:
320+
self._buffer.rewind()
321+
raise
322+
323+
def _read_response(self, disable_decoding=False):
318324
raw = self._buffer.readline()
319325
if not raw:
320326
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)
@@ -355,7 +361,7 @@ def read_response(self, disable_decoding=False):
355361
if length == -1:
356362
return None
357363
response = [
358-
self.read_response(disable_decoding=disable_decoding)
364+
self._read_response(disable_decoding=disable_decoding)
359365
for i in range(length)
360366
]
361367
if isinstance(response, bytes) and disable_decoding is False:

tests/test_connection.py

+81-1
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55

66
import pytest
77

8+
import redis
89
from redis.backoff import NoBackoff
9-
from redis.connection import Connection
10+
from redis.connection import Connection, PythonParser, HiredisParser
1011
from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
1112
from redis.retry import Retry
1213
from redis.utils import HIREDIS_AVAILABLE
@@ -122,3 +123,82 @@ def test_connect_timeout_error_without_retry(self):
122123
assert conn._connect.call_count == 1
123124
assert str(e.value) == "Timeout connecting to server"
124125
self.clear(conn)
126+
127+
128+
class FakeSocket:
129+
"""
130+
A class simulating an readable socket, but raising a
131+
special exception every other read.
132+
"""
133+
134+
class TestError(BaseException):
135+
pass
136+
137+
def __init__(self, data, interrupt_every=0):
138+
self.data = data
139+
self.counter = 0
140+
self.pos = 0
141+
self.interrupt_every = interrupt_every
142+
143+
def tick(self):
144+
self.counter += 1
145+
if not self.interrupt_every:
146+
return
147+
if (self.counter % self.interrupt_every) == 0:
148+
raise self.TestError()
149+
150+
def recv(self, bufsize):
151+
self.tick()
152+
bufsize = min(5, bufsize) # truncate the read size
153+
result = self.data[self.pos : self.pos + bufsize]
154+
self.pos += len(result)
155+
return result
156+
157+
def recv_into(self, buffer, nbytes=0, flags=0):
158+
self.tick()
159+
if nbytes == 0:
160+
nbytes = len(buffer)
161+
nbytes = min(5, nbytes) # truncate the read size
162+
result = self.data[self.pos : self.pos + nbytes]
163+
self.pos += len(result)
164+
buffer[: len(result)] = result
165+
return len(result)
166+
167+
168+
@pytest.mark.onlynoncluster
169+
@pytest.mark.parametrize(
170+
"parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"]
171+
)
172+
def test_connection_parse_response_resume(r: redis.Redis, parser_class):
173+
"""
174+
This test verifies that the Connection parser,
175+
be that PythonParser or HiredisParser,
176+
can be interrupted at IO time and then resume parsing.
177+
"""
178+
if parser_class is HiredisParser and not HIREDIS_AVAILABLE:
179+
pytest.skip("Hiredis not available)")
180+
args = dict(r.connection_pool.connection_kwargs)
181+
args["parser_class"] = parser_class
182+
conn = Connection(**args)
183+
conn.connect()
184+
message = (
185+
b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
186+
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
187+
)
188+
fake_socket = FakeSocket(message, interrupt_every=2)
189+
190+
if isinstance(conn._parser, PythonParser):
191+
conn._parser._buffer._sock = fake_socket
192+
else:
193+
conn._parser._sock = fake_socket
194+
for i in range(100):
195+
try:
196+
response = conn.read_response()
197+
break
198+
except FakeSocket.TestError:
199+
pass
200+
201+
else:
202+
pytest.fail("didn't receive a response")
203+
assert response
204+
assert i > 0

0 commit comments

Comments
 (0)