7
7
8
8
import redis
9
9
from redis .asyncio .connection import (
10
+ BaseParser ,
10
11
Connection ,
11
12
PythonParser ,
12
13
UnixDomainSocketConnection ,
@@ -24,16 +25,19 @@ async def test_invalid_response(create_redis):
24
25
r = await create_redis (single_connection_client = True )
25
26
26
27
raw = b"x"
28
+ fake_stream = FakeStream (raw + b"\r \n " )
27
29
28
- parser : "PythonParser" = r .connection ._parser
29
- if not isinstance (parser , PythonParser ):
30
- pytest .skip ("PythonParser only" )
31
- stream_mock = mock .Mock (parser ._stream )
32
- stream_mock .readline .return_value = raw + b"\r \n "
33
- with mock .patch .object (parser , "_stream" , stream_mock ):
30
+ parser : BaseParser = r .connection ._parser
31
+ with mock .patch .object (parser , "_stream" , fake_stream ):
34
32
with pytest .raises (InvalidResponse ) as cm :
35
33
await parser .read_response ()
36
- assert str (cm .value ) == f"Protocol Error: { raw !r} "
34
+ if isinstance (parser , PythonParser ):
35
+ assert str (cm .value ) == f"Protocol Error: { raw !r} "
36
+ else :
37
+ assert (
38
+ str (cm .value ) == f'Protocol error, got "{ raw .decode ()} " as reply type byte'
39
+ )
40
+ await r .connection .disconnect ()
37
41
38
42
39
43
@skip_if_server_version_lt ("4.0.0" )
@@ -115,26 +119,27 @@ async def test_connect_timeout_error_without_retry():
115
119
assert str (e .value ) == "Timeout connecting to server"
116
120
117
121
118
- class TestError (BaseException ):
119
- pass
120
-
121
-
122
- class InterruptingReader :
122
+ class FakeStream :
123
123
"""
124
124
A class simulating an asyncio input buffer, but raising a
125
125
special exception every other read.
126
126
"""
127
127
128
- def __init__ (self , data ):
128
+ class TestError (BaseException ):
129
+ pass
130
+
131
+ def __init__ (self , data , interrupt_every = 0 ):
129
132
self .data = data
130
133
self .counter = 0
131
134
self .pos = 0
135
+ self .interrupt_every = interrupt_every
132
136
133
137
def tick (self ):
134
138
self .counter += 1
135
- # return
136
- if (self .counter % 2 ) == 0 :
137
- raise TestError ()
139
+ if not self .interrupt_every :
140
+ return
141
+ if (self .counter % self .interrupt_every ) == 0 :
142
+ raise self .TestError ()
138
143
139
144
async def read (self , want ):
140
145
self .tick ()
@@ -176,12 +181,12 @@ async def test_connection_parse_response_resume(r: redis.Redis):
176
181
b"$25\r \n hi\r \n there\r \n +how\r \n are\r \n you\r \n "
177
182
)
178
183
179
- conn ._parser ._stream = InterruptingReader (message )
184
+ conn ._parser ._stream = FakeStream (message , interrupt_every = 2 )
180
185
for i in range (100 ):
181
186
try :
182
187
response = await conn .read_response ()
183
188
break
184
- except TestError :
189
+ except FakeStream . TestError :
185
190
pass
186
191
187
192
else :
0 commit comments