Skip to content

Commit ef8a38e

Browse files
committed
Port changes from #3424 to main / v7:
- Raise ``StopAsyncIteration`` if a ``ConnectionClosedOK`` is raised by the server in the websocket provider. Handle this appropriately instead of handling the exception directly. - Add a sanity check for AsyncIPC as well to stop iteration when / if the reader does not exist.
1 parent 39cde3b commit ef8a38e

File tree

6 files changed

+44
-36
lines changed

6 files changed

+44
-36
lines changed

newsfragments/3432.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Handle ``ConnectionClosedOK`` case for ``WebSocketProvider``. If the connection is closed gracefully, debug log and silently break out of the message iterator loop by raising ``StopAsyncIteration``.

web3/_utils/module_testing/module_testing_utils.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
from collections import (
2-
deque,
3-
)
1+
import asyncio
42
from typing import (
53
TYPE_CHECKING,
64
Any,
@@ -179,7 +177,9 @@ class WebSocketMessageStreamMock:
179177
def __init__(
180178
self, messages: Collection[bytes] = None, raise_exception: Exception = None
181179
) -> None:
182-
self.messages = deque(messages) if messages else deque()
180+
self.queue = asyncio.Queue() # type: ignore # py38 issue
181+
for msg in messages or []:
182+
self.queue.put_nowait(msg)
183183
self.raise_exception = raise_exception
184184

185185
def __await__(self) -> Generator[Any, Any, "Self"]:
@@ -192,13 +192,12 @@ def __aiter__(self) -> "Self":
192192
return self
193193

194194
async def __anext__(self) -> bytes:
195+
return await self.queue.get()
196+
197+
async def recv(self) -> bytes:
195198
if self.raise_exception:
196199
raise self.raise_exception
197-
198-
elif len(self.messages) == 0:
199-
raise StopAsyncIteration
200-
201-
return self.messages.popleft()
200+
return await self.queue.get()
202201

203202
@staticmethod
204203
async def pong() -> Literal[False]:

web3/manager.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
from hexbytes import (
2121
HexBytes,
2222
)
23-
from websockets.exceptions import (
24-
ConnectionClosedOK,
25-
)
2623

2724
from web3._utils.batching import (
2825
RequestBatcher,
@@ -586,7 +583,4 @@ def __aiter__(self) -> Self:
586583
return self
587584

588585
async def __anext__(self) -> RPCResponse:
589-
try:
590-
return await self.manager._get_next_message()
591-
except ConnectionClosedOK:
592-
raise StopAsyncIteration
586+
return await self.manager._get_next_message()

web3/providers/persistent/async_ipc.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,8 +171,15 @@ async def make_batch_request(
171171
return response
172172

173173
async def _provider_specific_message_listener(self) -> None:
174-
self._raw_message += to_text(await self._reader.read(4096)).lstrip()
174+
if self._reader is None:
175+
# sanity check to ensure the reader is initialized
176+
self.logger.debug(
177+
"IPC reader is not initialized. If this was not expected, "
178+
"check the connection status."
179+
)
180+
raise StopAsyncIteration
175181

182+
self._raw_message += to_text(await self._reader.read(4096)).lstrip()
176183
while self._raw_message:
177184
try:
178185
response, pos = self._decoder.raw_decode(self._raw_message)

web3/providers/persistent/persistent.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ async def _message_listener(self) -> None:
160160
await asyncio.sleep(0)
161161
try:
162162
await self._provider_specific_message_listener()
163+
except StopAsyncIteration:
164+
raise
163165
except Exception as e:
164166
if not self.silence_listener_task_exceptions:
165167
raise e
@@ -202,10 +204,6 @@ async def _match_response_id_to_request_id() -> RPCResponse:
202204
request_cache_key = generate_cache_key(request_id)
203205

204206
while True:
205-
# check if an exception was recorded in the listener task and raise it
206-
# in the main loop if so
207-
self._handle_listener_task_exceptions()
208-
209207
if request_cache_key in self._request_processor._request_response_cache:
210208
self.logger.debug(
211209
f"Popping response for id {request_id} from cache."
@@ -215,6 +213,9 @@ async def _match_response_id_to_request_id() -> RPCResponse:
215213
)
216214
return popped_response
217215
else:
216+
# check if an exception was recorded in the listener task and raise
217+
# it in the main loop if so
218+
self._handle_listener_task_exceptions()
218219
await asyncio.sleep(0)
219220

220221
try:

web3/providers/persistent/websocket.py

Lines changed: 21 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
connect,
2626
)
2727
from websockets.exceptions import (
28+
ConnectionClosedOK,
2829
WebSocketException,
2930
)
3031

@@ -169,18 +170,23 @@ async def make_batch_request(
169170
return response
170171

171172
async def _provider_specific_message_listener(self) -> None:
172-
async for raw_message in self._ws:
173-
await asyncio.sleep(0)
174-
175-
response = json.loads(raw_message)
176-
if isinstance(response, list):
177-
response = sort_batch_response_by_response_ids(response)
178-
179-
subscription = (
180-
response.get("method") == "eth_subscription"
181-
if not isinstance(response, list)
182-
else False
183-
)
184-
await self._request_processor.cache_raw_response(
185-
response, subscription=subscription
186-
)
173+
try:
174+
while True:
175+
raw_message = await self._ws.recv()
176+
await asyncio.sleep(0)
177+
178+
response = json.loads(raw_message)
179+
if isinstance(response, list):
180+
response = sort_batch_response_by_response_ids(response)
181+
182+
subscription = (
183+
response.get("method") == "eth_subscription"
184+
if not isinstance(response, list)
185+
else False
186+
)
187+
await self._request_processor.cache_raw_response(
188+
response, subscription=subscription
189+
)
190+
except ConnectionClosedOK:
191+
self.logger.debug("WebSocket connection closed gracefully.")
192+
raise StopAsyncIteration

0 commit comments

Comments
 (0)