diff --git a/docs/providers.rst b/docs/providers.rst index 020bc0d663..0bb35d4f69 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -270,16 +270,15 @@ asynchronous context manager, can be found in the `websockets connection`_ docs. ... # subscribe to new block headers ... subscription_id = await w3.eth.subscribe("newHeads") ... - ... unsubscribed = False - ... while not unsubscribed: - ... async for response in w3.ws.listen_to_websocket(): - ... print(f"{response}\n") - ... # handle responses here + ... async for response in w3.ws.listen_to_websocket(): + ... print(f"{response}\n") + ... # handle responses here ... - ... if some_condition: - ... # unsubscribe from new block headers - ... unsubscribed = await w3.eth.unsubscribe(subscription_id) - ... break + ... if some_condition: + ... # unsubscribe from new block headers and break out of + ... # iterator + ... await w3.eth.unsubscribe(subscription_id) + ... break ... ... # still an open connection, make any other requests and get ... # responses via send / receive diff --git a/newsfragments/3116.breaking.rst b/newsfragments/3116.breaking.rst new file mode 100644 index 0000000000..cfcc5477cb --- /dev/null +++ b/newsfragments/3116.breaking.rst @@ -0,0 +1 @@ +Refactor the async iterator pattern for message streams from the websocket connection for ``WebsocketProviderV2`` to a proper async iterator. This allows for a more natural usage of the iterator pattern and mimics the behavior of the underlying ``websockets`` library. diff --git a/newsfragments/3116.bugfix.rst b/newsfragments/3116.bugfix.rst new file mode 100644 index 0000000000..03015df060 --- /dev/null +++ b/newsfragments/3116.bugfix.rst @@ -0,0 +1 @@ +Fix issues with formatting middleware, such as ``async_geth_poa_middleware`` and subscription responses for ``WebsocketProviderV2``. diff --git a/newsfragments/3116.docs.rst b/newsfragments/3116.docs.rst new file mode 100644 index 0000000000..f7b7bc50fb --- /dev/null +++ b/newsfragments/3116.docs.rst @@ -0,0 +1 @@ +Updates to the ``WebsocketProviderV2`` documentation async iterator example for iterating over a persistent stream of messages from the websocket connection via ``async for``. diff --git a/tests/core/providers/test_wsv2_provider.py b/tests/core/providers/test_wsv2_provider.py new file mode 100644 index 0000000000..983e1de322 --- /dev/null +++ b/tests/core/providers/test_wsv2_provider.py @@ -0,0 +1,121 @@ +import json +import pytest +import sys + +from eth_utils import ( + to_bytes, +) + +from web3.exceptions import ( + TimeExhausted, +) +from web3.providers.websocket import ( + WebsocketProviderV2, +) +from web3.types import ( + RPCEndpoint, +) + + +def _mock_ws(provider): + # move to top of file when python 3.7 is no longer supported in web3.py + from unittest.mock import ( + AsyncMock, + ) + + provider._ws = AsyncMock() + + +@pytest.mark.asyncio +@pytest.mark.skipif( + # TODO: remove when python 3.7 is no longer supported in web3.py + # python 3.7 is already sunset so this feels like a reasonable tradeoff + sys.version_info < (3, 8), + reason="Uses AsyncMock, not supported by python 3.7", +) +async def test_async_make_request_caches_all_undesired_responses_and_returns_desired(): + provider = WebsocketProviderV2("ws://mocked") + + method_under_test = provider.make_request + + _mock_ws(provider) + undesired_responses_count = 10 + ws_recv_responses = [ + to_bytes( + text=json.dumps( + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": {"subscription": "0x1", "result": f"0x{i}"}, + } + ) + ) + for i in range(0, undesired_responses_count) + ] + # The first request we make should have an id of `0`, expect the response to match + # that id. Append it as the last response in the list. + ws_recv_responses.append(b'{"jsonrpc": "2.0", "id":0, "result": "0x1337"}') + provider._ws.recv.side_effect = ws_recv_responses + + response = await method_under_test(RPCEndpoint("some_method"), ["desired_params"]) + assert response == json.loads(ws_recv_responses.pop()) # pop the expected response + + assert ( + len(provider._request_processor._raw_response_cache) + == len(ws_recv_responses) + == undesired_responses_count + ) + + for ( + _cache_key, + cached_response, + ) in provider._request_processor._raw_response_cache.items(): + # assert all cached responses are in the list of responses we received + assert to_bytes(text=json.dumps(cached_response)) in ws_recv_responses + + +@pytest.mark.asyncio +@pytest.mark.skipif( + # TODO: remove when python 3.7 is no longer supported in web3.py + # python 3.7 is already sunset so this feels like a reasonable tradeoff + sys.version_info < (3, 8), + reason="Uses AsyncMock, not supported by python 3.7", +) +async def test_async_make_request_returns_cached_response_with_no_recv_if_cached(): + provider = WebsocketProviderV2("ws://mocked") + + method_under_test = provider.make_request + + _mock_ws(provider) + + # cache the response, so we should get it immediately & should never call `recv()` + desired_response = {"jsonrpc": "2.0", "id": 0, "result": "0x1337"} + await provider._request_processor.cache_raw_response(desired_response) + + response = await method_under_test(RPCEndpoint("some_method"), ["desired_params"]) + assert response == desired_response + + assert len(provider._request_processor._raw_response_cache) == 0 + assert not provider._ws.recv.called # type: ignore + + +@pytest.mark.asyncio +@pytest.mark.skipif( + # TODO: remove when python 3.7 is no longer supported in web3.py + # python 3.7 is already sunset so this feels like a reasonable tradeoff + sys.version_info < (3, 8), + reason="Uses AsyncMock, not supported by python 3.7", +) +async def test_async_make_request_times_out_of_while_loop_looking_for_response(): + provider = WebsocketProviderV2("ws://mocked", call_timeout=0.1) + + method_under_test = provider.make_request + + _mock_ws(provider) + provider._ws.recv.side_effect = lambda *args, **kwargs: b'{"jsonrpc": "2.0"}' + + with pytest.raises( + TimeExhausted, + match="Timed out waiting for response with request id `0` after 0.1 seconds.", + ): + await method_under_test(RPCEndpoint("some_method"), ["desired_params"]) diff --git a/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_generator_w3.py b/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_ctx_manager_w3.py similarity index 89% rename from tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_generator_w3.py rename to tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_ctx_manager_w3.py index 01d3bace3a..d9dbd0a3af 100644 --- a/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_generator_w3.py +++ b/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_ctx_manager_w3.py @@ -12,6 +12,9 @@ from web3._utils.module_testing.go_ethereum_personal_module import ( GoEthereumAsyncPersonalModuleTest, ) +from web3._utils.module_testing.persistent_connection_provider import ( + PersistentConnectionProviderTest, +) from ..common import ( GoEthereumAsyncEthModuleTest, @@ -25,6 +28,8 @@ @pytest_asyncio.fixture(scope="module") async def async_w3(geth_process, endpoint_uri): await wait_for_aiohttp(endpoint_uri) + + # async context manager pattern async with AsyncWeb3.persistent_websocket( WebsocketProviderV2(endpoint_uri, call_timeout=30) ) as w3: @@ -57,6 +62,10 @@ async def test_admin_start_stop_ws(self, async_w3: "AsyncWeb3") -> None: await super().test_admin_start_stop_ws(async_w3) +class TestPersistentConnectionProviderTest(PersistentConnectionProviderTest): + pass + + class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest): pass diff --git a/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_iterator_w3.py b/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_iterator_w3.py index fc321fc455..6d1907c7df 100644 --- a/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_iterator_w3.py +++ b/tests/integration/go_ethereum/test_goethereum_ws_v2/test_async_iterator_w3.py @@ -12,6 +12,9 @@ from web3._utils.module_testing.go_ethereum_personal_module import ( GoEthereumAsyncPersonalModuleTest, ) +from web3._utils.module_testing.persistent_connection_provider import ( + PersistentConnectionProviderTest, +) from ..common import ( GoEthereumAsyncEthModuleTest, @@ -25,6 +28,8 @@ @pytest_asyncio.fixture(scope="module") async def async_w3(geth_process, endpoint_uri): await wait_for_aiohttp(endpoint_uri) + + # async iterator pattern async for w3 in AsyncWeb3.persistent_websocket( WebsocketProviderV2(endpoint_uri, call_timeout=30) ): @@ -57,6 +62,10 @@ async def test_admin_start_stop_ws(self, async_w3: "AsyncWeb3") -> None: await super().test_admin_start_stop_ws(async_w3) +class TestPersistentConnectionProviderTest(PersistentConnectionProviderTest): + pass + + class TestGoEthereumAsyncEthModuleTest(GoEthereumAsyncEthModuleTest): pass diff --git a/web3/_utils/method_formatters.py b/web3/_utils/method_formatters.py index 630cc00e1a..e334adf465 100644 --- a/web3/_utils/method_formatters.py +++ b/web3/_utils/method_formatters.py @@ -643,13 +643,13 @@ def subscription_formatter(value: Any) -> Union[HexBytes, HexStr, Dict[str, Any] result_formatter = block_formatter elif either_set_is_a_subset( - result_key_set, set(LOG_ENTRY_FORMATTERS.keys()), percentage=90 + result_key_set, set(LOG_ENTRY_FORMATTERS.keys()), percentage=75 ): # logs result_formatter = log_entry_formatter elif either_set_is_a_subset( - result_key_set, set(TRANSACTION_RESULT_FORMATTERS.keys()), percentage=90 + result_key_set, set(TRANSACTION_RESULT_FORMATTERS.keys()), percentage=75 ): # newPendingTransactions, full transactions result_formatter = transaction_result_formatter @@ -663,7 +663,7 @@ def subscription_formatter(value: Any) -> Union[HexBytes, HexStr, Dict[str, Any] elif either_set_is_a_subset( result_key_set, set(SYNCING_FORMATTERS.keys()), - percentage=90, + percentage=75, ): # syncing response object result_formatter = syncing_formatter diff --git a/web3/_utils/module_testing/persistent_connection_provider.py b/web3/_utils/module_testing/persistent_connection_provider.py new file mode 100644 index 0000000000..08ee7c307d --- /dev/null +++ b/web3/_utils/module_testing/persistent_connection_provider.py @@ -0,0 +1,362 @@ +import json +import pytest +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Tuple, + cast, +) + +from eth_utils import ( + is_hexstr, + to_bytes, +) +from hexbytes import ( + HexBytes, +) + +from web3.datastructures import ( + AttributeDict, +) +from web3.middleware import ( + async_geth_poa_middleware, +) +from web3.types import ( + FormattedEthSubscriptionResponse, +) + +if TYPE_CHECKING: + from web3.main import ( + _PersistentConnectionWeb3, + ) + + +def _mocked_recv(sub_id: str, ws_subscription_response: Dict[str, Any]) -> bytes: + # Must be same subscription id so we can know how to parse the message. + # We don't have this information when mocking the response. + ws_subscription_response["params"]["subscription"] = sub_id + return to_bytes(text=json.dumps(ws_subscription_response)) + + +class PersistentConnectionProviderTest: + @pytest.mark.asyncio + @pytest.mark.parametrize( + "subscription_params,ws_subscription_response,expected_formatted_result", + ( + ( + ("newHeads",), + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": "THIS_WILL_BE_REPLACED_IN_THE_TEST", + "result": { + "number": "0x539", + "hash": "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e", # noqa: E501 + "parentHash": "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e", # noqa: E501 + "sha3Uncles": "0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347", # noqa: E501 + "logsBloom": "0x00", + "transactionsRoot": "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988", # noqa: E501 + "stateRoot": "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988", # noqa: E501 + "receiptsRoot": "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988", # noqa: E501 + "miner": "0x0000000000000000000000000000000000000000", + "difficulty": "0x0", + "extraData": "0x496c6c756d696e61746520446d6f63726174697a6520447374726962757465", # noqa: E501 + "gasLimit": "0x1c9c380", + "gasUsed": "0xd1ce44", + "timestamp": "0x539", + "baseFeePerGas": "0x26f93fef9", + "withdrawalsRoot": "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988", # noqa: E501 + "nonce": "0x0000000000000000", + "mixHash": "0x73e9e036ec894047f29954571d4b6d9e8717de7304269c263cbf150caa4e0768", # noqa: E501 + }, + }, + }, + AttributeDict( + { + "number": 1337, + "hash": HexBytes( + "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e" # noqa: E501 + ), + "parentHash": HexBytes( + "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e" # noqa: E501 + ), + "sha3Uncles": HexBytes( + "0x1dcc4de8dec75d7aab85b567b6ccd41ad312451b948a7413f0a142fd40d49347" # noqa: E501 + ), + "logsBloom": HexBytes("0x00"), + "transactionsRoot": HexBytes( + "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988" # noqa: E501 + ), + "stateRoot": HexBytes( + "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988" # noqa: E501 + ), + "receiptsRoot": HexBytes( + "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988" # noqa: E501 + ), + "miner": "0x0000000000000000000000000000000000000000", + "difficulty": 0, + "extraData": HexBytes( + "0x496c6c756d696e61746520446d6f63726174697a6520447374726962757465" # noqa: E501 + ), + "gasLimit": 30000000, + "gasUsed": 13749828, + "timestamp": 1337, + "baseFeePerGas": 10461904633, + "withdrawalsRoot": HexBytes( + "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988" # noqa: E501 + ), + "nonce": HexBytes("0x0000000000000000"), + "mixHash": HexBytes( + "0x73e9e036ec894047f29954571d4b6d9e8717de7304269c263cbf150caa4e0768" # noqa: E501 + ), + } + ), + ), + ( + ("newPendingTransactions", True), + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": "THIS_WILL_BE_REPLACED_IN_THE_TEST", + "result": { + "blockHash": None, + "blockNumber": None, + "from": "0x0000000000000000000000000000000000000000", + "gas": "0xf2f4", + "gasPrice": "0x29035f36f", + "maxFeePerGas": "0x29035f36f", + "maxPriorityFeePerGas": "0x3b9aca00", + "hash": "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e", # noqa: E501 + "input": "0x00", + "nonce": "0x2013", + "to": "0x0000000000000000000000000000000000000000", + "transactionIndex": None, + "value": "0x0", + "type": "0x2", + "accessList": [], + "chainId": "0x1", + "v": "0x1", + "r": "0x3c144a7c00ed3118d55445cd5be2ae4620ca377f7c685e9c5f3687671d4dece1", # noqa: E501 + "s": "0x284de67cbf75fec8a9edb368dee3a37cf6faba87f0af4413b2f869ebfa87d002", # noqa: E501 + "yParity": "0x1", + }, + }, + }, + AttributeDict( + { + "blockHash": None, + "blockNumber": None, + "from": "0x0000000000000000000000000000000000000000", + "gas": 62196, + "gasPrice": 11009389423, + "maxFeePerGas": 11009389423, + "maxPriorityFeePerGas": 1000000000, + "hash": HexBytes( + "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e" # noqa: E501 + ), + "input": HexBytes("0x00"), + "nonce": 8211, + "to": "0x0000000000000000000000000000000000000000", + "transactionIndex": None, + "value": 0, + "type": 2, + "accessList": [], + "chainId": 1, + "v": 1, + "r": HexBytes( + "0x3c144a7c00ed3118d55445cd5be2ae4620ca377f7c685e9c5f3687671d4dece1" # noqa: E501 + ), + "s": HexBytes( + "0x284de67cbf75fec8a9edb368dee3a37cf6faba87f0af4413b2f869ebfa87d002" # noqa: E501 + ), + "yParity": 1, + } + ), + ), + ( + ("newPendingTransactions", False), + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": "THIS_WILL_BE_REPLACED_IN_THE_TEST", + "result": "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e", # noqa: E501 + }, + }, + HexBytes( + "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e" + ), + ), + ( + ("logs", {"address": "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"}), + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": "THIS_WILL_BE_REPLACED_IN_THE_TEST", + "result": { + "removed": False, + "logIndex": "0x0", + "transactionIndex": "0x0", + "transactionHash": "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988", # noqa: E501 + "blockHash": "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e", # noqa: E501 + "blockNumber": "0x539", + "address": "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", + "data": "0x00", + "topics": [ + "0xe1fffdd4923d04f559f4d29e8bfc6cda04eb5b0d3c460751c2402c5c5cc9105c", # noqa: E501 + "0x00000000000000000000000016250d5630b4cf539739df2c5dacb4c659f2482d", # noqa: E501 + ], + }, + }, + }, + AttributeDict( + { + "removed": False, + "logIndex": 0, + "transactionIndex": 0, + "transactionHash": HexBytes( + "0x56260fe8298aff6d360e3a68fa855693f25dcb2708d8a7e509e8519b265d3988" # noqa: E501 + ), + "blockHash": HexBytes( + "0xb46b85928f2c2264c2bf7ad5c6d6985664f1527e744193ef990cc0d3da5afc5e" # noqa: E501 + ), + "blockNumber": 1337, + "address": "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2", + "data": HexBytes("0x00"), + "topics": [ + HexBytes( + "0xe1fffdd4923d04f559f4d29e8bfc6cda04eb5b0d3c460751c2402c5c5cc9105c" # noqa: E501 + ), + HexBytes( + "0x00000000000000000000000016250d5630b4cf539739df2c5dacb4c659f2482d" # noqa: E501 + ), + ], + } + ), + ), + ( + ("syncing",), + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": "THIS_WILL_BE_REPLACED_IN_THE_TEST", + "result": False, + }, + }, + False, + ), + ( + ("syncing",), + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": "THIS_WILL_BE_REPLACED_IN_THE_TEST", + "result": { + "isSyncing": True, + "startingBlock": "0x0", + "currentBlock": "0x4346fe", + "highestBlock": "0x434806", + }, + }, + }, + AttributeDict( + { + "isSyncing": True, + "startingBlock": 0, + "currentBlock": 4409086, + "highestBlock": 4409350, + } + ), + ), + ), + ids=[ + "newHeads", + "newPendingTransactions-FullTxs", + "newPendingTransactions-TxHashes", + "logs", + "syncing-False", + "syncing-True", + ], + ) + async def test_async_eth_subscribe_mocked( + self, + async_w3: "_PersistentConnectionWeb3", + subscription_params: Tuple[Any, ...], + ws_subscription_response: Dict[str, Any], + expected_formatted_result: Any, + ) -> None: + sub_id = await async_w3.eth.subscribe(*subscription_params) + assert is_hexstr(sub_id) + + async def _mocked_recv_coro() -> bytes: + return _mocked_recv(sub_id, ws_subscription_response) + + actual_recv_fxn = async_w3.provider._ws.recv + async_w3.provider._ws.__setattr__( + "recv", + _mocked_recv_coro, + ) + + async for msg in async_w3.ws.listen_to_websocket(): + response = cast(FormattedEthSubscriptionResponse, msg) + assert response["subscription"] == sub_id + assert response["result"] == expected_formatted_result + + # only testing one message, so break here + break + + # reset the mocked recv + async_w3.provider._ws.__setattr__("recv", actual_recv_fxn) + + @pytest.mark.asyncio + async def test_async_geth_poa_middleware_on_eth_subscription( + self, + async_w3: "_PersistentConnectionWeb3", + ) -> None: + async_w3.middleware_onion.inject( + async_geth_poa_middleware, "poa_middleware", layer=0 + ) + + sub_id = await async_w3.eth.subscribe("newHeads") + assert is_hexstr(sub_id) + + async def _mocked_recv_coro() -> bytes: + return _mocked_recv( + sub_id, + { + "jsonrpc": "2.0", + "method": "eth_subscription", + "params": { + "subscription": sub_id, + "result": { + "extraData": f"0x{'00' * 100}", + }, + }, + }, + ) + + actual_recv_fxn = async_w3.provider._ws.recv + async_w3.provider._ws.__setattr__( + "recv", + _mocked_recv_coro, + ) + + async for msg in async_w3.ws.listen_to_websocket(): + response = cast(FormattedEthSubscriptionResponse, msg) + assert response.keys() == {"subscription", "result"} + assert response["subscription"] == sub_id + assert response["result"]["proofOfAuthorityData"] == HexBytes( # type: ignore # noqa: E501 + f"0x{'00' * 100}" + ) + + break + + # reset the mocked recv + async_w3.provider._ws.__setattr__("recv", actual_recv_fxn) + async_w3.middleware_onion.remove("poa_middleware") diff --git a/web3/main.py b/web3/main.py index b97aea0c51..7dc5ac9a6c 100644 --- a/web3/main.py +++ b/web3/main.py @@ -543,6 +543,9 @@ def __init__( # async for w3 in w3.persistent_websocket(provider) async def __aiter__(self) -> AsyncIterator[Self]: + if not await self.provider.is_connected(): + await self.provider.connect() + while True: try: yield self diff --git a/web3/manager.py b/web3/manager.py index a0dc46ac96..0b740c0b36 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -22,12 +22,12 @@ ConnectionClosedOK, ) -from web3._utils.async_caching import ( - async_lock, -) from web3._utils.caching import ( generate_cache_key, ) +from web3._utils.compat import ( + Self, +) from web3.datastructures import ( NamedElementOnion, ) @@ -354,7 +354,7 @@ async def ws_send(self, method: RPCEndpoint, params: Any) -> RPCResponse: async def ws_recv(self) -> Any: return await self._ws_recv_stream().__anext__() - def persistent_recv_stream(self) -> "_AsyncPersistentRecvStream": + def _persistent_recv_stream(self) -> "_AsyncPersistentRecvStream": return _AsyncPersistentRecvStream(self) async def _ws_recv_stream(self) -> AsyncGenerator[RPCResponse, None]: @@ -366,21 +366,17 @@ async def _ws_recv_stream(self) -> AsyncGenerator[RPCResponse, None]: cached_responses = len(self._request_processor._raw_response_cache.items()) if cached_responses > 0: - async with async_lock( - self._provider._thread_pool, - self._provider._lock, - ): - self._provider.logger.debug( - f"{cached_responses} cached response(s) in raw response cache. " - f"Processing as FIFO ahead of any new responses from open " - f"socket connection." - ) - for ( - cache_key, - cached_response, - ) in self._request_processor._raw_response_cache.items(): - self._request_processor.pop_raw_response(cache_key) - yield await self._process_ws_response(cached_response) + self._provider.logger.debug( + f"{cached_responses} cached response(s) in raw response cache. " + f"Processing as FIFO ahead of any new responses from open " + f"socket connection." + ) + for ( + cache_key, + cached_response, + ) in self._request_processor._raw_response_cache.items(): + await self._request_processor.pop_raw_response(cache_key) + yield await self._process_ws_response(cached_response) else: response = await self._provider._ws_recv() yield await self._process_ws_response(response) @@ -440,10 +436,11 @@ def __init__(self, manager: RequestManager, *args: Any, **kwargs: Any) -> None: self.manager = manager super().__init__(*args, **kwargs) - def __aiter__(self) -> AsyncGenerator[RPCResponse, None]: - while True: - try: - # solely listen to the stream, no request id necessary - return self.manager._ws_recv_stream() - except ConnectionClosedOK: - pass + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> RPCResponse: + try: + return await self.manager.ws_recv() + except ConnectionClosedOK: + raise StopAsyncIteration diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index ef84050474..0a26ff3c3f 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -16,6 +16,7 @@ from web3.types import ( AsyncMiddleware, AsyncMiddlewareCoroutine, + EthSubscriptionParams, Formatters, FormattersDict, Literal, @@ -48,17 +49,36 @@ def _apply_response_formatters( response: RPCResponse, ) -> RPCResponse: def _format_response( - response_type: Literal["result", "error"], + response_type: Literal["result", "error", "params"], method_response_formatter: Callable[..., Any], ) -> RPCResponse: appropriate_response = response[response_type] - return assoc( - response, response_type, method_response_formatter(appropriate_response) - ) + if response_type == "params": + appropriate_response = cast(EthSubscriptionParams, response[response_type]) + return assoc( + response, + response_type, + assoc( + response["params"], + "result", + method_response_formatter(appropriate_response["result"]), + ), + ) + else: + return assoc( + response, response_type, method_response_formatter(appropriate_response) + ) if response.get("result") is not None and method in result_formatters: return _format_response("result", result_formatters[method]) + elif ( + # eth_subscription responses + response.get("params") is not None + and response["params"].get("result") is not None + and method in result_formatters + ): + return _format_response("params", result_formatters[method]) elif "error" in response and method in error_formatters: return _format_response("error", error_formatters[method]) else: diff --git a/web3/middleware/geth_poa.py b/web3/middleware/geth_poa.py index 9cd7ba93f1..dd80431991 100644 --- a/web3/middleware/geth_poa.py +++ b/web3/middleware/geth_poa.py @@ -4,6 +4,9 @@ Callable, ) +from eth_utils import ( + is_dict, +) from eth_utils.curried import ( apply_formatter_if, apply_formatters_to_dict, @@ -68,6 +71,11 @@ async def async_geth_poa_middleware( result_formatters={ RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), + RPC.eth_subscribe: apply_formatter_if( + is_not_null, + # original call to eth_subscribe returns a string, needs a dict check + apply_formatter_if(is_dict, geth_poa_cleanup), + ), }, ) return await middleware(make_request, w3) diff --git a/web3/providers/persistent.py b/web3/providers/persistent.py index 74896cfea8..9f033b3ab3 100644 --- a/web3/providers/persistent.py +++ b/web3/providers/persistent.py @@ -1,11 +1,7 @@ from abc import ( ABC, ) -from concurrent.futures import ( - ThreadPoolExecutor, -) import logging -import threading from typing import ( Optional, ) @@ -33,14 +29,12 @@ class PersistentConnectionProvider(AsyncJSONBaseProvider, ABC): _ws: Optional[WebSocketClientProtocol] = None _request_processor: RequestProcessor - _thread_pool: ThreadPoolExecutor = ThreadPoolExecutor() - _lock: threading.Lock = threading.Lock() def __init__( self, endpoint_uri: str, request_cache_size: int = 100, - call_timeout: int = DEFAULT_PERSISTENT_CONNECTION_TIMEOUT, + call_timeout: float = DEFAULT_PERSISTENT_CONNECTION_TIMEOUT, ) -> None: super().__init__() self.endpoint_uri = endpoint_uri diff --git a/web3/providers/websocket/request_processor.py b/web3/providers/websocket/request_processor.py index 1102d15267..7be87f7fd0 100644 --- a/web3/providers/websocket/request_processor.py +++ b/web3/providers/websocket/request_processor.py @@ -1,3 +1,4 @@ +import asyncio from copy import ( copy, ) @@ -32,6 +33,8 @@ class RequestProcessor: _request_information_cache: SimpleCache + _raw_response_cache: SimpleCache + _raw_response_cache_lock: asyncio.Lock = asyncio.Lock() def __init__( self, @@ -166,7 +169,7 @@ def append_middleware_response_processor( # raw response cache - def cache_raw_response(self, raw_response: Any) -> None: + async def cache_raw_response(self, raw_response: Any) -> None: # get id or generate a uuid if not present (i.e. subscription response) response_id = raw_response.get("id", f"sub-{uuid4()}") cache_key = generate_cache_key(response_id) @@ -174,15 +177,18 @@ def cache_raw_response(self, raw_response: Any) -> None: f"Caching raw response:\n response_id={response_id},\n" f" cache_key={cache_key},\n raw_response={raw_response}" ) - self._raw_response_cache.cache(cache_key, raw_response) + async with self._raw_response_cache_lock: + self._raw_response_cache.cache(cache_key, raw_response) - def pop_raw_response(self, cache_key: str) -> Any: - raw_response = self._raw_response_cache.pop(cache_key) + async def pop_raw_response(self, cache_key: str) -> Any: + async with self._raw_response_cache_lock: + raw_response = self._raw_response_cache.pop(cache_key) self._provider.logger.debug( f"Cached response processed and popped from cache:\n" f" cache_key={cache_key},\n" f" raw_response={raw_response}" ) + return raw_response # request processor class methods diff --git a/web3/providers/websocket/websocket_connection.py b/web3/providers/websocket/websocket_connection.py index a8f4e70eb7..2b3cb883a6 100644 --- a/web3/providers/websocket/websocket_connection.py +++ b/web3/providers/websocket/websocket_connection.py @@ -33,4 +33,4 @@ async def recv(self) -> Any: return await self._w3.manager.ws_recv() def listen_to_websocket(self) -> "_AsyncPersistentRecvStream": - return self._w3.manager.persistent_recv_stream() + return self._w3.manager._persistent_recv_stream() diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index 6bc513e990..64a1e788b5 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -22,17 +22,16 @@ WebSocketException, ) -from web3._utils.async_caching import ( - async_lock, -) from web3._utils.caching import ( generate_cache_key, ) from web3.exceptions import ( ProviderConnectionError, + TimeExhausted, Web3ValidationError, ) from web3.providers.persistent import ( + DEFAULT_PERSISTENT_CONNECTION_TIMEOUT, PersistentConnectionProvider, ) from web3.types import ( @@ -67,7 +66,7 @@ def __init__( self, endpoint_uri: Optional[Union[URI, str]] = None, websocket_kwargs: Optional[Dict[str, Any]] = None, - call_timeout: Optional[int] = None, + call_timeout: Optional[float] = DEFAULT_PERSISTENT_CONNECTION_TIMEOUT, ) -> None: self.endpoint_uri = URI(endpoint_uri) if self.endpoint_uri is None: @@ -158,40 +157,55 @@ async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: await asyncio.wait_for(self._ws.send(request_data), timeout=self.call_timeout) current_request_id = json.loads(request_data)["id"] + request_cache_key = generate_cache_key(current_request_id) - response = await self._ws_recv() - response_id = response.get("id") + if request_cache_key in self._request_processor._raw_response_cache: + # if response is already cached, pop it from cache + response = await self._request_processor.pop_raw_response(request_cache_key) + else: + # else, wait for the desired response, caching all others along the way + response = await self._get_response_for_request_id(current_request_id) - if response_id != current_request_id: - request_cache_key = generate_cache_key(current_request_id) - if request_cache_key in self._request_processor._raw_response_cache: - async with async_lock(self._thread_pool, self._lock): - # if response is already cached, pop it from the cache - response = self._request_processor.pop_raw_response( - request_cache_key - ) - else: - async with async_lock(self._thread_pool, self._lock): - # cache response - self._request_processor.cache_raw_response(response) - response = await asyncio.wait_for( - self._get_response_for_request_id(current_request_id), - self.call_timeout, - ) return response async def _get_response_for_request_id(self, request_id: RPCId) -> RPCResponse: - response = await self._ws_recv() - response_id = response.get("id") - - while response_id != request_id: - response = await self._ws_recv() - response_id = response.get("id") - if response_id != request_id: - self._request_processor.cache_raw_response( - response, - ) - return response + async def _match_response_id_to_request_id() -> RPCResponse: + response_id = None + response = None + while response_id != request_id: + response = await self._ws_recv() + response_id = response.get("id") + + if response_id == request_id: + break + else: + # cache all responses that are not the desired response + await self._request_processor.cache_raw_response( + response, + ) + await asyncio.sleep(0.1) + + return response + + try: + # Enters a while loop, looking for a response id match to the request id. + # If the provider does not give responses with matching ids, this will + # hang forever. The JSON-RPC spec requires that providers respond with + # the same id that was sent in the request, but we need to handle these + # "bad" cases somewhat gracefully. + timeout = ( + self.call_timeout + if self.call_timeout and self.call_timeout <= 20 + else 20 + ) + return await asyncio.wait_for(_match_response_id_to_request_id(), timeout) + except asyncio.TimeoutError: + raise TimeExhausted( + f"Timed out waiting for response with request id `{request_id}` after " + f"{self.call_timeout} seconds. This is likely due to the provider not " + "returning a response with the same id that was sent in the request, " + "which is required by the JSON-RPC spec." + ) async def _ws_recv(self) -> RPCResponse: return json.loads( diff --git a/web3/types.py b/web3/types.py index ef154c06d9..b98f81785c 100644 --- a/web3/types.py +++ b/web3/types.py @@ -281,6 +281,14 @@ class GethSyncingSubscriptionResponse(SubscriptionResponse): result: GethSyncingSubscriptionResult +EthSubscriptionParams = Union[ + BlockTypeSubscriptionResponse, + TransactionTypeSubscriptionResponse, + LogsSubscriptionResponse, + SyncingSubscriptionResponse, + GethSyncingSubscriptionResponse, +] + RPCId = Optional[Union[int, str]] @@ -292,12 +300,13 @@ class RPCResponse(TypedDict, total=False): # eth_subscribe method: Literal["eth_subscription"] - params: Union[ - BlockTypeSubscriptionResponse, - TransactionTypeSubscriptionResponse, - LogsSubscriptionResponse, - SyncingSubscriptionResponse, - GethSyncingSubscriptionResponse, + params: EthSubscriptionParams + + +class FormattedEthSubscriptionResponse(TypedDict): + subscription: HexStr + result: Union[ + BlockData, TxData, LogReceipt, SyncProgress, GethSyncingSubscriptionResult ]