Skip to content

Commit a97064e

Browse files
committed
Move MockStream and MockSocket into their own files
1 parent 445861d commit a97064e

File tree

4 files changed

+101
-95
lines changed

4 files changed

+101
-95
lines changed

tests/mocks.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Various mocks for testing
2+
3+
4+
class MockSocket:
5+
"""
6+
A class simulating an readable socket, optionally raising a
7+
special exception every other read.
8+
"""
9+
10+
class TestError(BaseException):
11+
pass
12+
13+
def __init__(self, data, interrupt_every=0):
14+
self.data = data
15+
self.counter = 0
16+
self.pos = 0
17+
self.interrupt_every = interrupt_every
18+
19+
def tick(self):
20+
self.counter += 1
21+
if not self.interrupt_every:
22+
return
23+
if (self.counter % self.interrupt_every) == 0:
24+
raise self.TestError()
25+
26+
def recv(self, bufsize):
27+
self.tick()
28+
bufsize = min(5, bufsize) # truncate the read size
29+
result = self.data[self.pos : self.pos + bufsize]
30+
self.pos += len(result)
31+
return result
32+
33+
def recv_into(self, buffer, nbytes=0, flags=0):
34+
self.tick()
35+
if nbytes == 0:
36+
nbytes = len(buffer)
37+
nbytes = min(5, nbytes) # truncate the read size
38+
result = self.data[self.pos : self.pos + nbytes]
39+
self.pos += len(result)
40+
buffer[: len(result)] = result
41+
return len(result)

tests/test_asyncio/mocks.py

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import asyncio
2+
3+
# Helper Mocking classes for the tests.
4+
5+
6+
class MockStream:
7+
"""
8+
A class simulating an asyncio input buffer, optionally raising a
9+
special exception every other read.
10+
"""
11+
12+
class TestError(BaseException):
13+
pass
14+
15+
def __init__(self, data, interrupt_every=0):
16+
self.data = data
17+
self.counter = 0
18+
self.pos = 0
19+
self.interrupt_every = interrupt_every
20+
21+
def tick(self):
22+
self.counter += 1
23+
if not self.interrupt_every:
24+
return
25+
if (self.counter % self.interrupt_every) == 0:
26+
raise self.TestError()
27+
28+
async def read(self, want):
29+
self.tick()
30+
want = 5
31+
result = self.data[self.pos : self.pos + want]
32+
self.pos += len(result)
33+
return result
34+
35+
async def readline(self):
36+
self.tick()
37+
find = self.data.find(b"\n", self.pos)
38+
if find >= 0:
39+
result = self.data[self.pos : find + 1]
40+
else:
41+
result = self.data[self.pos :]
42+
self.pos += len(result)
43+
return result
44+
45+
async def readexactly(self, length):
46+
self.tick()
47+
result = self.data[self.pos : self.pos + length]
48+
if len(result) < length:
49+
raise asyncio.IncompleteReadError(result, None)
50+
self.pos += len(result)
51+
return result

tests/test_asyncio/test_connection.py

+4-51
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,15 @@
1818
from tests.conftest import skip_if_server_version_lt
1919

2020
from .compat import mock
21+
from .mocks import MockStream
2122

2223

2324
@pytest.mark.onlynoncluster
2425
async def test_invalid_response(create_redis):
2526
r = await create_redis(single_connection_client=True)
2627

2728
raw = b"x"
28-
fake_stream = FakeStream(raw + b"\r\n")
29+
fake_stream = MockStream(raw + b"\r\n")
2930

3031
parser: BaseParser = r.connection._parser
3132
with mock.patch.object(parser, "_stream", fake_stream):
@@ -119,54 +120,6 @@ async def test_connect_timeout_error_without_retry():
119120
assert str(e.value) == "Timeout connecting to server"
120121

121122

122-
class FakeStream:
123-
"""
124-
A class simulating an asyncio input buffer, but raising a
125-
special exception every other read.
126-
"""
127-
128-
class TestError(BaseException):
129-
pass
130-
131-
def __init__(self, data, interrupt_every=0):
132-
self.data = data
133-
self.counter = 0
134-
self.pos = 0
135-
self.interrupt_every = interrupt_every
136-
137-
def tick(self):
138-
self.counter += 1
139-
if not self.interrupt_every:
140-
return
141-
if (self.counter % self.interrupt_every) == 0:
142-
raise self.TestError()
143-
144-
async def read(self, want):
145-
self.tick()
146-
want = 5
147-
result = self.data[self.pos : self.pos + want]
148-
self.pos += len(result)
149-
return result
150-
151-
async def readline(self):
152-
self.tick()
153-
find = self.data.find(b"\n", self.pos)
154-
if find >= 0:
155-
result = self.data[self.pos : find + 1]
156-
else:
157-
result = self.data[self.pos :]
158-
self.pos += len(result)
159-
return result
160-
161-
async def readexactly(self, length):
162-
self.tick()
163-
result = self.data[self.pos : self.pos + length]
164-
if len(result) < length:
165-
raise asyncio.IncompleteReadError(result, None)
166-
self.pos += len(result)
167-
return result
168-
169-
170123
@pytest.mark.onlynoncluster
171124
async def test_connection_parse_response_resume(r: redis.Redis):
172125
"""
@@ -181,12 +134,12 @@ async def test_connection_parse_response_resume(r: redis.Redis):
181134
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
182135
)
183136

184-
conn._parser._stream = FakeStream(message, interrupt_every=2)
137+
conn._parser._stream = MockStream(message, interrupt_every=2)
185138
for i in range(100):
186139
try:
187140
response = await conn.read_response()
188141
break
189-
except FakeStream.TestError:
142+
except MockStream.TestError:
190143
pass
191144

192145
else:

tests/test_connection.py

+5-44
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from redis.utils import HIREDIS_AVAILABLE
1414

1515
from .conftest import skip_if_server_version_lt
16+
from .mocks import MockSocket
1617

1718

1819
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
@@ -125,46 +126,6 @@ def test_connect_timeout_error_without_retry(self):
125126
self.clear(conn)
126127

127128

128-
class FakeSocket:
129-
"""
130-
A class simulating an readable socket, but raising a
131-
special exception every other read.
132-
"""
133-
134-
class TestError(BaseException):
135-
pass
136-
137-
def __init__(self, data, interrupt_every=0):
138-
self.data = data
139-
self.counter = 0
140-
self.pos = 0
141-
self.interrupt_every = interrupt_every
142-
143-
def tick(self):
144-
self.counter += 1
145-
if not self.interrupt_every:
146-
return
147-
if (self.counter % self.interrupt_every) == 0:
148-
raise self.TestError()
149-
150-
def recv(self, bufsize):
151-
self.tick()
152-
bufsize = min(5, bufsize) # truncate the read size
153-
result = self.data[self.pos : self.pos + bufsize]
154-
self.pos += len(result)
155-
return result
156-
157-
def recv_into(self, buffer, nbytes=0, flags=0):
158-
self.tick()
159-
if nbytes == 0:
160-
nbytes = len(buffer)
161-
nbytes = min(5, nbytes) # truncate the read size
162-
result = self.data[self.pos : self.pos + nbytes]
163-
self.pos += len(result)
164-
buffer[: len(result)] = result
165-
return len(result)
166-
167-
168129
@pytest.mark.onlynoncluster
169130
@pytest.mark.parametrize(
170131
"parser_class", [PythonParser, HiredisParser], ids=["PythonParser", "HiredisParser"]
@@ -185,17 +146,17 @@ def test_connection_parse_response_resume(r: redis.Redis, parser_class):
185146
b"*3\r\n$7\r\nmessage\r\n$8\r\nchannel1\r\n"
186147
b"$25\r\nhi\r\nthere\r\n+how\r\nare\r\nyou\r\n"
187148
)
188-
fake_socket = FakeSocket(message, interrupt_every=2)
149+
mock_socket = MockSocket(message, interrupt_every=2)
189150

190151
if isinstance(conn._parser, PythonParser):
191-
conn._parser._buffer._sock = fake_socket
152+
conn._parser._buffer._sock = mock_socket
192153
else:
193-
conn._parser._sock = fake_socket
154+
conn._parser._sock = mock_socket
194155
for i in range(100):
195156
try:
196157
response = conn.read_response()
197158
break
198-
except FakeSocket.TestError:
159+
except MockSocket.TestError:
199160
pass
200161

201162
else:

0 commit comments

Comments
 (0)