Skip to content

Commit 858b440

Browse files
committed
Update tests
1 parent a3b9d8a commit 858b440

File tree

1 file changed

+23
-18
lines changed

1 file changed

+23
-18
lines changed

tests/test_asyncio/test_connection.py

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import redis
99
from redis.asyncio.connection import (
10+
BaseParser,
1011
Connection,
1112
PythonParser,
1213
UnixDomainSocketConnection,
@@ -24,16 +25,19 @@ async def test_invalid_response(create_redis):
2425
r = await create_redis(single_connection_client=True)
2526

2627
raw = b"x"
28+
fake_stream = FakeStream(raw + b"\r\n")
2729

28-
parser: "PythonParser" = r.connection._parser
29-
if not isinstance(parser, PythonParser):
30-
pytest.skip("PythonParser only")
31-
stream_mock = mock.Mock(parser._stream)
32-
stream_mock.readline.return_value = raw + b"\r\n"
33-
with mock.patch.object(parser, "_stream", stream_mock):
30+
parser: BaseParser = r.connection._parser
31+
with mock.patch.object(parser, "_stream", fake_stream):
3432
with pytest.raises(InvalidResponse) as cm:
3533
await parser.read_response()
36-
assert str(cm.value) == f"Protocol Error: {raw!r}"
34+
if isinstance(parser, PythonParser):
35+
assert str(cm.value) == f"Protocol Error: {raw!r}"
36+
else:
37+
assert (
38+
str(cm.value) == f'Protocol error, got "{raw.decode()}" as reply type byte'
39+
)
40+
await r.connection.disconnect()
3741

3842

3943
@skip_if_server_version_lt("4.0.0")
@@ -115,26 +119,27 @@ async def test_connect_timeout_error_without_retry():
115119
assert str(e.value) == "Timeout connecting to server"
116120

117121

118-
class TestError(BaseException):
119-
pass
120-
121-
122-
class InterruptingReader:
122+
class FakeStream:
123123
"""
124124
A class simulating an asyncio input buffer, but raising a
125125
special exception every other read.
126126
"""
127127

128-
def __init__(self, data):
128+
class TestError(BaseException):
129+
pass
130+
131+
def __init__(self, data, interrupt_every=0):
129132
self.data = data
130133
self.counter = 0
131134
self.pos = 0
135+
self.interrupt_every = interrupt_every
132136

133137
def tick(self):
134138
self.counter += 1
135-
# return
136-
if (self.counter % 2) == 0:
137-
raise TestError()
139+
if not self.interrupt_every:
140+
return
141+
if (self.counter % self.interrupt_every) == 0:
142+
raise self.TestError()
138143

139144
async def read(self, want):
140145
self.tick()
@@ -176,12 +181,12 @@ async def test_connection_parse_response_resume(r: redis.Redis):
176181
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
177182
)
178183

179-
conn._parser._stream = InterruptingReader(message)
184+
conn._parser._stream = FakeStream(message, interrupt_every=2)
180185
for i in range(100):
181186
try:
182187
response = await conn.read_response()
183188
break
184-
except TestError:
189+
except FakeStream.TestError:
185190
pass
186191

187192
else:

0 commit comments

Comments
 (0)