diff --git a/.gitignore b/.gitignore index ed1b949468..87e55458ea 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ docs/web3.gas_strategies.rst docs/web3.middleware.rst docs/web3.providers.eth_tester.rst docs/web3.providers.rst +docs/web3.providers.rpc.rst docs/web3.providers.websocket.rst docs/web3.rst docs/web3.scripts.release.rst diff --git a/docs/conf.py b/docs/conf.py index 78f0cf94c7..210918e4f6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -86,6 +86,7 @@ "web3.gas_strategies.rst", "web3.middleware.rst", "web3.providers.rst", + "web3.providers.rpc.rst", "web3.providers.websocket.rst", "web3.providers.eth_tester.rst", "web3.scripts.*", @@ -224,7 +225,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ("index", "Populus.tex", "Populus Documentation", "The Ethereum Foundation", "manual"), + ( + "index", + "Populus.tex", + "Populus Documentation", + "The Ethereum Foundation", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of diff --git a/docs/examples.rst b/docs/examples.rst index 9260d81b8d..97099f3925 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -688,8 +688,8 @@ Inject the middleware into the middleware onion .. code-block:: python - from web3.middleware import geth_poa_middleware - w3.middleware_onion.inject(geth_poa_middleware, layer=0) + from web3.middleware import extradata_to_poa_middleware + w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) Just remember that you have to sign all transactions locally, as infura does not handle any keys from your wallet ( refer to `this`_ ) diff --git a/docs/middleware.rst b/docs/middleware.rst index baf3869982..aa767bca0f 100644 --- a/docs/middleware.rst +++ b/docs/middleware.rst @@ -104,18 +104,6 @@ Buffered Gas Estimate ``min(w3.eth.estimate_gas + gas_buffer, gas_limit)`` where the gas_buffer default is 100,000 -HTTPRequestRetry -~~~~~~~~~~~~~~~~~~ - -.. py:method:: web3.middleware.http_retry_request_middleware - web3.middleware.async_http_retry_request_middleware - - This middleware is a default specifically for HTTPProvider that retries failed - requests that return the following errors: ``ConnectionError``, ``HTTPError``, ``Timeout``, - ``TooManyRedirects``. Additionally there is a whitelist that only allows certain - methods to be retried in order to not resend transactions, excluded methods are: - ``eth_sendTransaction``, ``personal_signAndSendTransaction``, ``personal_sendTransaction``. - Validation ~~~~~~~~~~~~~~~~~~~~~~~~ @@ -446,16 +434,15 @@ Time-based Cache Middleware Proof of Authority ~~~~~~~~~~~~~~~~~~ -.. py:method:: web3.middleware.geth_poa_middleware - web3.middleware.async_geth_poa_middleware +.. py:class:: web3.middleware.extradata_to_poa_middleware .. note:: It's important to inject the middleware at the 0th layer of the middleware onion: - ``w3.middleware_onion.inject(geth_poa_middleware, layer=0)`` + ``w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0)`` -The ``geth_poa_middleware`` is required to connect to ``geth --dev`` or the Goerli -public network. It may also be needed for other EVM compatible blockchains like Polygon -or BNB Chain (Binance Smart Chain). +The ``extradata_to_poa_middleware`` is required to connect to ``geth --dev`` and may +also be needed for other EVM compatible blockchains like Polygon or +BNB Chain (Binance Smart Chain). If the middleware is not injected at the 0th layer of the middleware onion, you may get errors like the example below when interacting with your EVM node. @@ -468,7 +455,8 @@ errors like the example below when interacting with your EVM node. for more details. The full extraData is: HexBytes('...') -The easiest way to connect to a default ``geth --dev`` instance which loads the middleware is: +The easiest way to connect to a default ``geth --dev`` instance which loads the +middleware is: .. code-block:: python @@ -489,25 +477,26 @@ unique IPC location and loads the middleware: # connect to the IPC location started with 'geth --dev --datadir ~/mynode' >>> w3 = Web3(IPCProvider('~/mynode/geth.ipc')) - >>> from web3.middleware import geth_poa_middleware + >>> from web3.middleware import extradata_to_poa_middleware # inject the poa compatibility middleware to the innermost layer (0th layer) - >>> w3.middleware_onion.inject(geth_poa_middleware, layer=0) + >>> w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) # confirm that the connection succeeded >>> w3.client_version 'Geth/v1.7.3-stable-4bb3c89d/linux-amd64/go1.9' -Why is ``geth_poa_middleware`` necessary? -''''''''''''''''''''''''''''''''''''''''' +Why is ``extradata_to_poa_middleware`` necessary? +''''''''''''''''''''''''''''''''''''''''''''''''' There is no strong community consensus on a single Proof-of-Authority (PoA) standard yet. Some nodes have successful experiments running though. One is go-ethereum (geth), which uses a prototype PoA for its development mode and the Goerli test network. Unfortunately, it does deviate from the yellow paper specification, which constrains the -``extraData`` field in each block to a maximum of 32-bytes. Geth's PoA uses more than -32 bytes, so this middleware modifies the block data a bit before returning it. +``extraData`` field in each block to a maximum of 32-bytes. Geth is one such example +where PoA uses more than 32 bytes, so this middleware modifies the block data a bit +before returning it. .. _local-filter: diff --git a/docs/providers.rst b/docs/providers.rst index a87a68fca0..6f983ab769 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -467,7 +467,6 @@ AsyncHTTPProvider - :meth:`Attribute Dict Middleware ` - :meth:`Buffered Gas Estimate Middleware ` - :meth:`Gas Price Strategy Middleware ` - - :meth:`Geth POA Middleware ` - :meth:`Local Filter Middleware ` - :meth:`Simple Cache Middleware ` - :meth:`Stalecheck Middleware ` diff --git a/ens/async_ens.py b/ens/async_ens.py index 48a2646a12..f476df4dad 100644 --- a/ens/async_ens.py +++ b/ens/async_ens.py @@ -71,12 +71,14 @@ AsyncContractFunction, ) from web3.main import AsyncWeb3 # noqa: F401 + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) from web3.providers import ( # noqa: F401 AsyncBaseProvider, BaseProvider, ) from web3.types import ( # noqa: F401 - AsyncMiddleware, TxParams, ) @@ -98,7 +100,7 @@ def __init__( self, provider: "AsyncBaseProvider" = cast("AsyncBaseProvider", default), addr: ChecksumAddress = None, - middlewares: Optional[Sequence[Tuple["AsyncMiddleware", str]]] = None, + middlewares: Optional[Sequence[Tuple["Middleware", str]]] = None, ) -> None: """ :param provider: a single provider used to connect to Ethereum diff --git a/ens/ens.py b/ens/ens.py index 97b24a2896..fcc0ecdb4f 100644 --- a/ens/ens.py +++ b/ens/ens.py @@ -71,11 +71,13 @@ Contract, ContractFunction, ) + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) from web3.providers import ( # noqa: F401 BaseProvider, ) from web3.types import ( # noqa: F401 - Middleware, TxParams, ) diff --git a/ens/utils.py b/ens/utils.py index fd5e81fd44..314beac68c 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -5,7 +5,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Collection, Dict, List, @@ -60,15 +59,15 @@ AsyncWeb3, Web3 as _Web3, ) + from web3.middleware.base import ( + Middleware, + ) from web3.providers import ( # noqa: F401 AsyncBaseProvider, BaseProvider, ) from web3.types import ( # noqa: F401 ABIFunction, - AsyncMiddleware, - Middleware, - RPCEndpoint, ) @@ -104,13 +103,14 @@ def customize_web3(w3: "_Web3") -> "_Web3": make_stalecheck_middleware, ) - if w3.middleware_onion.get("name_to_address"): - w3.middleware_onion.remove("name_to_address") + if w3.middleware_onion.get("ens_name_to_address"): + w3.middleware_onion.remove("ens_name_to_address") if not w3.middleware_onion.get("stalecheck"): - w3.middleware_onion.add( - make_stalecheck_middleware(ACCEPTABLE_STALE_HOURS * 3600), name="stalecheck" + stalecheck_middleware = make_stalecheck_middleware( + ACCEPTABLE_STALE_HOURS * 3600 ) + w3.middleware_onion.add(stalecheck_middleware, name="stalecheck") return w3 @@ -299,7 +299,7 @@ def get_abi_output_types(abi: "ABIFunction") -> List[str]: def init_async_web3( provider: "AsyncBaseProvider" = cast("AsyncBaseProvider", default), - middlewares: Optional[Sequence[Tuple["AsyncMiddleware", str]]] = (), + middlewares: Optional[Sequence[Tuple["Middleware", str]]] = (), ) -> "AsyncWeb3": from web3 import ( AsyncWeb3 as AsyncWeb3Main, @@ -307,14 +307,19 @@ def init_async_web3( from web3.eth import ( AsyncEth as AsyncEthMain, ) + from web3.middleware import ( + make_stalecheck_middleware, + ) middlewares = list(middlewares) for i, (middleware, name) in enumerate(middlewares): - if name == "name_to_address": + if name == "ens_name_to_address": middlewares.pop(i) if "stalecheck" not in (name for mw, name in middlewares): - middlewares.append((_async_ens_stalecheck_middleware, "stalecheck")) + middlewares.append( + (make_stalecheck_middleware(ACCEPTABLE_STALE_HOURS * 3600), "stalecheck") + ) if provider is default: async_w3 = AsyncWeb3Main( @@ -329,14 +334,3 @@ def init_async_web3( ) return async_w3 - - -async def _async_ens_stalecheck_middleware( - make_request: Callable[["RPCEndpoint", Any], Any], w3: "AsyncWeb3" -) -> "Middleware": - from web3.middleware import ( - async_make_stalecheck_middleware, - ) - - middleware = await async_make_stalecheck_middleware(ACCEPTABLE_STALE_HOURS * 3600) - return await middleware(make_request, w3) diff --git a/newsfragments/3169.breaking.rst b/newsfragments/3169.breaking.rst new file mode 100644 index 0000000000..60fad67045 --- /dev/null +++ b/newsfragments/3169.breaking.rst @@ -0,0 +1 @@ +Refactor the middleware setup so that request processors and response processors are separated. This will allow for more flexibility in the future and aid in the implementation of features such as batched requests. This PR also closes out a few outstanding issues and will be the start of the breaking changes for `web3.py` ``v7``. Review PR for a full list of changes. diff --git a/setup.py b/setup.py index 45278eccf8..264e0ac581 100644 --- a/setup.py +++ b/setup.py @@ -78,6 +78,7 @@ "jsonschema>=4.0.0", "lru-dict>=1.1.6,<1.3.0", "protobuf>=4.21.6", + "pydantic>=2.4.0", "pywin32>=223;platform_system=='Windows'", "requests>=2.16.0", "typing-extensions>=4.0.1", diff --git a/tests/conftest.py b/tests/conftest.py index 9e194dbcf8..6a4b502867 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ import pytest +from typing import ( + Type, +) from eth_utils import ( event_signature_to_log_topic, @@ -7,10 +10,14 @@ from eth_utils.toolz import ( identity, ) +import pytest_asyncio from web3._utils.contract_sources.contract_data.emitter_contract import ( EMITTER_CONTRACT_DATA, ) +from web3._utils.module_testing.utils import ( + RequestMocker, +) from .utils import ( get_open_port, @@ -107,3 +114,11 @@ class LogTopics: @pytest.fixture(scope="session") def emitter_contract_log_topics(): return LogTopics + + +# -- mock requests -- # + + +@pytest_asyncio.fixture(scope="function") +def request_mocker() -> Type[RequestMocker]: + return RequestMocker diff --git a/tests/core/caching-utils/test_request_caching.py b/tests/core/caching-utils/test_request_caching.py new file mode 100644 index 0000000000..6307e02dfa --- /dev/null +++ b/tests/core/caching-utils/test_request_caching.py @@ -0,0 +1,236 @@ +import itertools +import pytest +import threading +import uuid + +import pytest_asyncio + +from web3 import ( + AsyncWeb3, + Web3, +) +from web3._utils.caching import ( + generate_cache_key, +) +from web3.providers import ( + AsyncBaseProvider, + BaseProvider, +) +from web3.types import ( + RPCEndpoint, +) +from web3.utils import ( + SimpleCache, +) + + +def simple_cache_return_value_a(): + _cache = SimpleCache() + _cache.cache( + generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"), + {"result": "value-a"}, + ) + return _cache + + +@pytest.fixture +def w3(request_mocker): + _w3 = Web3(provider=BaseProvider()) + _w3.provider.cache_allowed_requests = True + _w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + with request_mocker( + _w3, + mock_results={ + "fake_endpoint": lambda *_: uuid.uuid4(), + "not_on_allowlist": lambda *_: uuid.uuid4(), + }, + ): + yield _w3 + + # clear request cache after each test + _w3.provider._request_cache.clear() + + +def test_request_caching_pulls_from_cache(w3): + w3.provider._request_cache = simple_cache_return_value_a() + assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" + + +def test_request_caching_populates_cache(w3): + result = w3.manager.request_blocking("fake_endpoint", []) + assert w3.manager.request_blocking("fake_endpoint", []) == result + assert w3.manager.request_blocking("fake_endpoint", [1]) != result + assert len(w3.provider._request_cache.items()) == 2 + + +def test_request_caching_does_not_cache_none_responses(request_mocker): + w3 = Web3(BaseProvider()) + w3.provider.cache_allowed_requests = True + w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + counter = itertools.count() + + def result_cb(_method, _params): + next(counter) + return None + + with request_mocker(w3, mock_results={"fake_endpoint": result_cb}): + w3.manager.request_blocking("fake_endpoint", []) + w3.manager.request_blocking("fake_endpoint", []) + + assert next(counter) == 2 + + +def test_request_caching_does_not_cache_error_responses(request_mocker): + w3 = Web3(provider=BaseProvider()) + w3.provider.cache_allowed_requests = True + w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + with request_mocker( + w3, mock_errors={"fake_endpoint": lambda *_: {"message": f"msg-{uuid.uuid4()}"}} + ): + with pytest.raises(ValueError) as err_a: + w3.manager.request_blocking("fake_endpoint", []) + with pytest.raises(ValueError) as err_b: + w3.manager.request_blocking("fake_endpoint", []) + + assert str(err_a) != str(err_b) + assert err_a.value.args != err_b.value.args + + +def test_request_caching_does_not_cache_endpoints_not_in_allowlist(w3): + result_a = w3.manager.request_blocking("not_on_allowlist", []) + result_b = w3.manager.request_blocking("not_on_allowlist", []) + assert result_a != result_b + + +def test_caching_requests_does_not_share_state_between_providers(request_mocker): + w3_a, w3_b, w3_c = ( + Web3(provider=BaseProvider()), + Web3(provider=BaseProvider()), + Web3(provider=BaseProvider()), + ) + mock_results_a = {RPCEndpoint("eth_chainId"): 11111} + mock_results_b = {RPCEndpoint("eth_chainId"): 22222} + mock_results_c = {RPCEndpoint("eth_chainId"): 33333} + + with request_mocker(w3_a, mock_results=mock_results_a): + with request_mocker(w3_b, mock_results=mock_results_b): + with request_mocker(w3_c, mock_results=mock_results_c): + result_a = w3_a.manager.request_blocking("eth_chainId", []) + result_b = w3_b.manager.request_blocking("eth_chainId", []) + result_c = w3_c.manager.request_blocking("eth_chainId", []) + + assert result_a == 11111 + assert result_b == 22222 + assert result_c == 33333 + + +# -- async -- # + + +@pytest_asyncio.fixture +async def async_w3(request_mocker): + _async_w3 = AsyncWeb3(AsyncBaseProvider()) + _async_w3.provider.cache_allowed_requests = True + _async_w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + async with request_mocker( + _async_w3, + mock_results={ + "fake_endpoint": lambda *_: uuid.uuid4(), + "not_on_allowlist": lambda *_: uuid.uuid4(), + }, + ): + yield _async_w3 + + # clear request cache after each test + _async_w3.provider._request_cache.clear() + + +@pytest.mark.asyncio +async def test_async_request_caching_pulls_from_cache(async_w3): + async_w3.provider._request_cache = simple_cache_return_value_a() + _result = await async_w3.manager.coro_request("fake_endpoint", [1]) + assert _result == "value-a" + + +@pytest.mark.asyncio +async def test_async_request_caching_populates_cache(async_w3): + result = await async_w3.manager.coro_request("fake_endpoint", []) + + _empty_params = await async_w3.manager.coro_request("fake_endpoint", []) + _non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1]) + + assert _empty_params == result + assert _non_empty_params != result + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_cache_none_responses(request_mocker): + async_w3 = AsyncWeb3(AsyncBaseProvider()) + async_w3.provider.cache_allowed_requests = True + async_w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + counter = itertools.count() + + def result_cb(_method, _params): + next(counter) + return None + + async with request_mocker(async_w3, mock_results={"fake_endpoint": result_cb}): + await async_w3.manager.coro_request("fake_endpoint", []) + await async_w3.manager.coro_request("fake_endpoint", []) + + assert next(counter) == 2 + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_cache_error_responses(request_mocker): + async_w3 = AsyncWeb3(AsyncBaseProvider()) + async_w3.provider.cache_allowed_requests = True + async_w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + async with request_mocker( + async_w3, + mock_errors={"fake_endpoint": lambda *_: {"message": f"msg-{uuid.uuid4()}"}}, + ): + with pytest.raises(ValueError) as err_a: + await async_w3.manager.coro_request("fake_endpoint", []) + with pytest.raises(ValueError) as err_b: + await async_w3.manager.coro_request("fake_endpoint", []) + + assert str(err_a) != str(err_b) + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_cache_non_allowlist_endpoints( + async_w3, +): + result_a = await async_w3.manager.coro_request("not_on_allowlist", []) + result_b = await async_w3.manager.coro_request("not_on_allowlist", []) + assert result_a != result_b + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_share_state_between_providers( + request_mocker, +): + async_w3_a, async_w3_b, async_w3_c = ( + AsyncWeb3(AsyncBaseProvider()), + AsyncWeb3(AsyncBaseProvider()), + AsyncWeb3(AsyncBaseProvider()), + ) + mock_results_a = {RPCEndpoint("eth_chainId"): 11111} + mock_results_b = {RPCEndpoint("eth_chainId"): 22222} + mock_results_c = {RPCEndpoint("eth_chainId"): 33333} + + async with request_mocker(async_w3_a, mock_results=mock_results_a): + async with request_mocker(async_w3_b, mock_results=mock_results_b): + async with request_mocker(async_w3_c, mock_results=mock_results_c): + result_a = await async_w3_a.manager.coro_request("eth_chainId", []) + result_b = await async_w3_b.manager.coro_request("eth_chainId", []) + result_c = await async_w3_c.manager.coro_request("eth_chainId", []) + + assert result_a == 11111 + assert result_b == 22222 + assert result_c == 33333 diff --git a/tests/core/contracts/test_contract_example.py b/tests/core/contracts/test_contract_example.py index 7ed27390e9..738b365e4f 100644 --- a/tests/core/contracts/test_contract_example.py +++ b/tests/core/contracts/test_contract_example.py @@ -5,12 +5,10 @@ import pytest_asyncio from web3 import ( + AsyncWeb3, EthereumTesterProvider, Web3, ) -from web3.eth import ( - AsyncEth, -) from web3.providers.eth_tester.main import ( AsyncEthereumTesterProvider, ) @@ -121,10 +119,9 @@ def async_eth_tester(): @pytest_asyncio.fixture() async def async_w3(): - provider = AsyncEthereumTesterProvider() - w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=provider.middlewares) - w3.eth.default_account = await w3.eth.coinbase - return w3 + async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) + async_w3.eth.default_account = await async_w3.eth.coinbase + return async_w3 @pytest_asyncio.fixture() diff --git a/tests/core/eth-module/test_block_api.py b/tests/core/eth-module/test_block_api.py index b9abae1329..2e4edf2808 100644 --- a/tests/core/eth-module/test_block_api.py +++ b/tests/core/eth-module/test_block_api.py @@ -7,13 +7,6 @@ HexBytes, ) -from web3._utils.rpc_abi import ( - RPC, -) -from web3.middleware import ( - construct_result_generator_middleware, -) - @pytest.fixture(autouse=True) def wait_for_first_block(w3, wait_for_block): @@ -26,7 +19,7 @@ def test_uses_default_block(w3, extra_accounts, wait_for_transaction): assert w3.eth.default_block == w3.eth.block_number -def test_get_block_formatters_with_null_values(w3): +def test_get_block_formatters_with_null_values(w3, request_mocker): null_values_block = { "baseFeePerGas": None, "extraData": None, @@ -51,19 +44,12 @@ def test_get_block_formatters_with_null_values(w3): "withdrawalsRoot": None, "withdrawals": [], } - result_middleware = construct_result_generator_middleware( - { - RPC.eth_getBlockByNumber: lambda *_: null_values_block, - } - ) - - w3.middleware_onion.inject(result_middleware, "result_middleware", layer=0) - - received_block = w3.eth.get_block("pending") + with request_mocker(w3, mock_results={"eth_getBlockByNumber": null_values_block}): + received_block = w3.eth.get_block("pending") assert received_block == null_values_block -def test_get_block_formatters_with_pre_formatted_values(w3): +def test_get_block_formatters_with_pre_formatted_values(w3, request_mocker): unformatted_values_block = { "baseFeePerGas": "0x3b9aca00", "extraData": "0x", @@ -116,15 +102,11 @@ def test_get_block_formatters_with_pre_formatted_values(w3): }, ], } - result_middleware = construct_result_generator_middleware( - { - RPC.eth_getBlockByNumber: lambda *_: unformatted_values_block, - } - ) - - w3.middleware_onion.inject(result_middleware, "result_middleware", layer=0) - received_block = w3.eth.get_block("pending") + with request_mocker( + w3, mock_results={"eth_getBlockByNumber": unformatted_values_block} + ): + received_block = w3.eth.get_block("pending") assert received_block == { "baseFeePerGas": int(unformatted_values_block["baseFeePerGas"], 16), diff --git a/tests/core/eth-module/test_poa.py b/tests/core/eth-module/test_poa.py index 45884066cb..6dfc678fc5 100644 --- a/tests/core/eth-module/test_poa.py +++ b/tests/core/eth-module/test_poa.py @@ -5,54 +5,42 @@ ExtraDataLengthError, ) from web3.middleware import ( - construct_fixture_middleware, - geth_poa_middleware, + extradata_to_poa_middleware, ) # In the spec, a block with extra data longer than 32 bytes is invalid -def test_long_extra_data(w3): - return_block_with_long_extra_data = construct_fixture_middleware( - { - "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, - } - ) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - with pytest.raises(ExtraDataLengthError): - w3.eth.get_block("latest") - - -def test_full_extra_data(w3): - return_block_with_long_extra_data = construct_fixture_middleware( - { - "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 32}, - } - ) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - block = w3.eth.get_block("latest") - assert block.extraData == b"\xff" * 32 - - -def test_geth_proof_of_authority(w3): - return_block_with_long_extra_data = construct_fixture_middleware( - { +def test_long_extra_data(w3, request_mocker): + with request_mocker( + w3, mock_results={"eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}} + ): + with pytest.raises(ExtraDataLengthError): + w3.eth.get_block("latest") + + +def test_full_extra_data(w3, request_mocker): + with request_mocker( + w3, mock_results={"eth_getBlockByNumber": {"extraData": "0x" + "ff" * 32}} + ): + block = w3.eth.get_block("latest") + assert block.extraData == b"\xff" * 32 + + +def test_extradata_to_poa_middleware(w3, request_mocker): + w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) + + with request_mocker( + w3, + mock_results={ "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, - } - ) - w3.middleware_onion.inject(geth_poa_middleware, layer=0) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - block = w3.eth.get_block("latest") - assert "extraData" not in block - assert block.proofOfAuthorityData == b"\xff" * 33 - - -def test_returns_none_response(w3): - return_none_response = construct_fixture_middleware( - { - "eth_getBlockByNumber": None, - } - ) - w3.middleware_onion.inject(geth_poa_middleware, layer=0) - w3.middleware_onion.inject(return_none_response, layer=0) - with pytest.raises(BlockNotFound): - w3.eth.get_block(100000000000) + }, + ): + block = w3.eth.get_block("latest") + assert "extraData" not in block + assert block.proofOfAuthorityData == b"\xff" * 33 + + +def test_returns_none_response(w3, request_mocker): + with request_mocker(w3, mock_results={"eth_getBlockByNumber": None}): + with pytest.raises(BlockNotFound): + w3.eth.get_block("latest") diff --git a/tests/core/eth-module/test_transactions.py b/tests/core/eth-module/test_transactions.py index ce09c471d0..7484fd29d2 100644 --- a/tests/core/eth-module/test_transactions.py +++ b/tests/core/eth-module/test_transactions.py @@ -1,3 +1,5 @@ +import collections +import itertools import pytest from eth_utils import ( @@ -23,12 +25,6 @@ TransactionNotFound, Web3ValidationError, ) -from web3.middleware import ( - construct_result_generator_middleware, -) -from web3.middleware.simulate_unmined_transaction import ( - unmined_receipt_simulator_middleware, -) RECEIPT_TIMEOUT = 0.2 @@ -176,8 +172,9 @@ def test_passing_string_to_to_hex(w3): w3.eth.wait_for_transaction_receipt(transaction_hash, timeout=RECEIPT_TIMEOUT) -def test_unmined_transaction_wait_for_receipt(w3): - w3.middleware_onion.add(unmined_receipt_simulator_middleware) +def test_unmined_transaction_wait_for_receipt(w3, request_mocker): + receipt_counters = collections.defaultdict(itertools.count) + txn_hash = w3.eth.send_transaction( { "from": w3.eth.coinbase, @@ -185,15 +182,25 @@ def test_unmined_transaction_wait_for_receipt(w3): "value": 123457, } ) - with pytest.raises(TransactionNotFound): - w3.eth.get_transaction_receipt(txn_hash) + unmocked_make_request = w3.provider.make_request + + with request_mocker( + w3, + mock_results={ + RPC.eth_getTransactionReceipt: lambda method, params: None + if next(receipt_counters[params[0]]) < 5 + else unmocked_make_request(method, params)["result"] + }, + ): + with pytest.raises(TransactionNotFound): + w3.eth.get_transaction_receipt(txn_hash) - txn_receipt = w3.eth.wait_for_transaction_receipt(txn_hash) - assert txn_receipt["transactionHash"] == txn_hash - assert txn_receipt["blockHash"] is not None + txn_receipt = w3.eth.wait_for_transaction_receipt(txn_hash) + assert txn_receipt["transactionHash"] == txn_hash + assert txn_receipt["blockHash"] is not None -def test_get_transaction_formatters(w3): +def test_get_transaction_formatters(w3, request_mocker): non_checksummed_addr = "0xB2930B35844A230F00E51431ACAE96FE543A0347" # all uppercase unformatted_transaction = { "blockHash": ( @@ -236,15 +243,11 @@ def test_get_transaction_formatters(w3): "data": "0x5b34b966", } - result_middleware = construct_result_generator_middleware( - { - RPC.eth_getTransactionByHash: lambda *_: unformatted_transaction, - } - ) - w3.middleware_onion.inject(result_middleware, "result_middleware", layer=0) - - # test against eth_getTransactionByHash - received_tx = w3.eth.get_transaction("") + with request_mocker( + w3, mock_results={RPC.eth_getTransactionByHash: unformatted_transaction} + ): + # test against eth_getTransactionByHash + received_tx = w3.eth.get_transaction("") checksummed_addr = to_checksum_address(non_checksummed_addr) assert non_checksummed_addr != checksummed_addr @@ -299,4 +302,3 @@ def test_get_transaction_formatters(w3): ) assert received_tx == expected - w3.middleware_onion.remove("result_middleware") diff --git a/tests/core/filtering/conftest.py b/tests/core/filtering/conftest.py index 331381be2a..25a7a1bec5 100644 --- a/tests/core/filtering/conftest.py +++ b/tests/core/filtering/conftest.py @@ -67,7 +67,7 @@ def create_filter(request): @pytest.fixture( scope="function", params=[True, False], - ids=["async_local_filter_middleware", "node_based_filter"], + ids=["local_filter_middleware", "node_based_filter"], ) def async_w3(request): return _async_w3_fixture_logic(request) diff --git a/tests/core/filtering/utils.py b/tests/core/filtering/utils.py index 714505966a..d7b0f8cbe0 100644 --- a/tests/core/filtering/utils.py +++ b/tests/core/filtering/utils.py @@ -2,11 +2,7 @@ AsyncWeb3, Web3, ) -from web3.eth import ( - AsyncEth, -) from web3.middleware import ( - async_local_filter_middleware, local_filter_middleware, ) from web3.providers.eth_tester import ( @@ -49,10 +45,10 @@ def _emitter_fixture_logic( def _async_w3_fixture_logic(request): use_filter_middleware = request.param provider = AsyncEthereumTesterProvider() - async_w3 = AsyncWeb3(provider, modules={"eth": [AsyncEth]}, middlewares=[]) + async_w3 = AsyncWeb3(provider) if use_filter_middleware: - async_w3.middleware_onion.add(async_local_filter_middleware) + async_w3.middleware_onion.add(local_filter_middleware) return async_w3 diff --git a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py index f0f159799a..1b46f42039 100644 --- a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py +++ b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py @@ -10,9 +10,6 @@ from web3.gas_strategies.time_based import ( construct_time_based_gas_price_strategy, ) -from web3.middleware import ( - construct_result_generator_middleware, -) from web3.providers.base import ( BaseProvider, ) @@ -150,25 +147,22 @@ def _get_block_by_something(method, params): (dict(max_wait_seconds=80, sample_size=5, probability=50, weighted=True), 11), ), ) -def test_time_based_gas_price_strategy(strategy_params, expected): - fixture_middleware = construct_result_generator_middleware( - { - "eth_getBlockByHash": _get_block_by_something, - "eth_getBlockByNumber": _get_block_by_something, - } - ) - - w3 = Web3( - provider=BaseProvider(), - middlewares=[fixture_middleware], - ) +def test_time_based_gas_price_strategy(strategy_params, expected, request_mocker): + w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( **strategy_params, ) w3.eth.set_gas_price_strategy(time_based_gas_price_strategy) - actual = w3.eth.generate_gas_price() - assert actual == expected + with request_mocker( + w3, + mock_results={ + "eth_getBlockByHash": _get_block_by_something, + "eth_getBlockByNumber": _get_block_by_something, + }, + ): + actual = w3.eth.generate_gas_price() + assert actual == expected def _get_initial_block(method, params): @@ -186,19 +180,8 @@ def _get_gas_price(method, params): return 4321 -def test_time_based_gas_price_strategy_without_transactions(): - fixture_middleware = construct_result_generator_middleware( - { - "eth_getBlockByHash": _get_initial_block, - "eth_getBlockByNumber": _get_initial_block, - "eth_gasPrice": _get_gas_price, - } - ) - - w3 = Web3( - provider=BaseProvider(), - middlewares=[fixture_middleware], - ) +def test_time_based_gas_price_strategy_without_transactions(request_mocker): + w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( max_wait_seconds=80, @@ -207,8 +190,16 @@ def test_time_based_gas_price_strategy_without_transactions(): weighted=True, ) w3.eth.set_gas_price_strategy(time_based_gas_price_strategy) - actual = w3.eth.generate_gas_price() - assert actual == w3.eth.gas_price + with request_mocker( + w3, + mock_results={ + "eth_getBlockByHash": _get_initial_block, + "eth_getBlockByNumber": _get_initial_block, + "eth_gasPrice": _get_gas_price, + }, + ): + actual = w3.eth.generate_gas_price() + assert actual == w3.eth.gas_price @pytest.mark.parametrize( @@ -269,23 +260,20 @@ def test_time_based_gas_price_strategy_without_transactions(): ), ) def test_time_based_gas_price_strategy_zero_sample( - strategy_params_zero, expected_exception_message + strategy_params_zero, expected_exception_message, request_mocker ): with pytest.raises(Web3ValidationError) as excinfo: - fixture_middleware = construct_result_generator_middleware( - { - "eth_getBlockByHash": _get_block_by_something, - "eth_getBlockByNumber": _get_block_by_something, - } - ) - - w3 = Web3( - provider=BaseProvider(), - middlewares=[fixture_middleware], - ) + w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy_zero = construct_time_based_gas_price_strategy( **strategy_params_zero, ) w3.eth.set_gas_price_strategy(time_based_gas_price_strategy_zero) - w3.eth.generate_gas_price() + with request_mocker( + w3, + mock_results={ + "eth_getBlockByHash": _get_block_by_something, + "eth_getBlockByNumber": _get_block_by_something, + }, + ): + w3.eth.generate_gas_price() assert str(excinfo.value) == expected_exception_message diff --git a/tests/core/manager/conftest.py b/tests/core/manager/conftest.py index 458204dcd5..cc2c86c063 100644 --- a/tests/core/manager/conftest.py +++ b/tests/core/manager/conftest.py @@ -1,6 +1,10 @@ import itertools import pytest +from web3.middleware.base import ( + Web3Middleware, +) + @pytest.fixture def middleware_factory(): @@ -15,15 +19,19 @@ class Wrapper: def __repr__(self): return "middleware-" + key - def __call__(self, make_request, w3): - def middleware_fn(method, params): - params.append(key) - method = "|".join((method, key)) - response = make_request(method, params) - response["result"]["middlewares"].append(key) - return response + def __call__(self, web3): + class Middleware(Web3Middleware): + def _wrap_make_request(self, make_request): + def middleware_fn(method, params): + params.append(key) + method = "|".join((method, key)) + response = make_request(method, params) + response["result"]["middlewares"].append(key) + return response + + return middleware_fn - return middleware_fn + return Middleware(web3) return Wrapper() diff --git a/tests/core/manager/test_default_middlewares.py b/tests/core/manager/test_default_middlewares.py index 64b8a5d1d2..1ad566808d 100644 --- a/tests/core/manager/test_default_middlewares.py +++ b/tests/core/manager/test_default_middlewares.py @@ -2,16 +2,10 @@ RequestManager, ) from web3.middleware import ( - abi_middleware, - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_name_to_address_middleware, - async_validation_middleware, attrdict_middleware, buffered_gas_estimate_middleware, + ens_name_to_address_middleware, gas_price_strategy_middleware, - name_to_address_middleware, validation_middleware, ) @@ -19,31 +13,12 @@ def test_default_sync_middlewares(w3): expected_middlewares = [ (gas_price_strategy_middleware, "gas_price_strategy"), - (name_to_address_middleware(w3), "name_to_address"), + (ens_name_to_address_middleware, "ens_name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), - (abi_middleware, "abi"), (buffered_gas_estimate_middleware, "gas_estimate"), ] - default_middlewares = RequestManager.default_middlewares(w3) + default_middlewares = RequestManager.get_default_middlewares() - for x in range(len(default_middlewares)): - assert default_middlewares[x][0].__name__ == expected_middlewares[x][0].__name__ - assert default_middlewares[x][1] == expected_middlewares[x][1] - - -def test_default_async_middlewares(): - expected_middlewares = [ - (async_gas_price_strategy_middleware, "gas_price_strategy"), - (async_name_to_address_middleware, "name_to_address"), - (async_attrdict_middleware, "attrdict"), - (async_validation_middleware, "validation"), - (async_buffered_gas_estimate_middleware, "gas_estimate"), - ] - - default_middlewares = RequestManager.async_default_middlewares() - - for x in range(len(default_middlewares)): - assert default_middlewares[x][0].__name__ == expected_middlewares[x][0].__name__ - assert default_middlewares[x][1] == expected_middlewares[x][1] + assert default_middlewares == expected_middlewares diff --git a/tests/core/manager/test_middleware_can_be_stateful.py b/tests/core/manager/test_middleware_can_be_stateful.py index b666dc25a0..ae8ef8314c 100644 --- a/tests/core/manager/test_middleware_can_be_stateful.py +++ b/tests/core/manager/test_middleware_can_be_stateful.py @@ -1,20 +1,26 @@ from web3.manager import ( RequestManager, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.providers import ( BaseProvider, ) -def stateful_middleware(make_request, w3): +class StatefulMiddleware(Web3Middleware): state = [] - def middleware(method, params): - state.append((method, params)) - return {"result": state} + def _wrap_make_request(self, make_request): + def middleware(method, params): + self.state.append((method, params)) + return {"result": self.state} + + return middleware + - middleware.state = state - return middleware +stateful_middleware = StatefulMiddleware def test_middleware_holds_state_across_requests(): diff --git a/tests/core/method-class/test_result_formatters.py b/tests/core/method-class/test_result_formatters.py index 324cfc55e5..9bea0eb8d9 100644 --- a/tests/core/method-class/test_result_formatters.py +++ b/tests/core/method-class/test_result_formatters.py @@ -10,9 +10,6 @@ from web3.method import ( Method, ) -from web3.middleware.fixture import ( - construct_result_generator_middleware, -) from web3.module import ( Module, ) @@ -33,11 +30,7 @@ def make_request(method, params): raise NotImplementedError -result_middleware = construct_result_generator_middleware( - { - "method_for_test": lambda m, p: "ok", - } -) +result_for_test = {"method_for_test": "ok"} class ModuleForTest(Module): @@ -48,11 +41,11 @@ class ModuleForTest(Module): def dummy_w3(): w3 = Web3( DummyProvider(), - middlewares=[result_middleware], modules={"module": ModuleForTest}, ) return w3 -def test_result_formatter(dummy_w3): - assert dummy_w3.module.method() == "OKAY" +def test_result_formatter(dummy_w3, request_mocker): + with request_mocker(dummy_w3, mock_results=result_for_test): + assert dummy_w3.module.method() == "OKAY" diff --git a/tests/core/middleware/test_attrdict_middleware.py b/tests/core/middleware/test_attrdict_middleware.py index 21d6de8717..da03afe326 100644 --- a/tests/core/middleware/test_attrdict_middleware.py +++ b/tests/core/middleware/test_attrdict_middleware.py @@ -9,26 +9,18 @@ AttributeDict, ) from web3.middleware import ( - async_attrdict_middleware, - async_construct_result_generator_middleware, attrdict_middleware, - construct_result_generator_middleware, ) from web3.providers.eth_tester import ( AsyncEthereumTesterProvider, ) -from web3.types import ( - RPCEndpoint, -) GENERATED_NESTED_DICT_RESULT = { - "result": { - "a": 1, - "b": { - "b1": 1, - "b2": {"b2a": 1, "b2b": {"b2b1": 1, "b2b2": {"test": "fin"}}}, - }, - } + "a": 1, + "b": { + "b1": 1, + "b2": {"b2a": 1, "b2b": {"b2b1": 1, "b2b2": {"test": "fin"}}}, + }, } @@ -42,19 +34,14 @@ def test_attrdict_middleware_default_for_ethereum_tester_provider(): assert w3.middleware_onion.get("attrdict") == attrdict_middleware -def test_attrdict_middleware_is_recursive(w3): - w3.middleware_onion.inject( - construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - response = w3.manager.request_blocking("fake_endpoint", []) +def test_attrdict_middleware_is_recursive(w3, request_mocker): + with request_mocker( + w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = w3.manager.request_blocking("fake_endpoint", []) - result = response["result"] assert isinstance(result, AttributeDict) - assert response.result == result assert isinstance(result["b"], AttributeDict) assert result.b == result["b"] @@ -65,27 +52,17 @@ def test_attrdict_middleware_is_recursive(w3): assert isinstance(result.b.b2.b2b["b2b2"], AttributeDict) assert result.b.b2.b2b.b2b2 == result.b.b2.b2b["b2b2"] - # cleanup - w3.middleware_onion.remove("result_gen") - -def test_no_attrdict_middleware_does_not_convert_dicts_to_attrdict(): +def test_no_attrdict_middleware_does_not_convert_dicts_to_attrdict(request_mocker): w3 = Web3(EthereumTesterProvider()) - - w3.middleware_onion.inject( - construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - # remove attrdict middleware w3.middleware_onion.remove("attrdict") - response = w3.manager.request_blocking("fake_endpoint", []) - - result = response["result"] + with request_mocker( + w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = w3.manager.request_blocking("fake_endpoint", []) _assert_dict_and_not_attrdict(result) _assert_dict_and_not_attrdict(result["b"]) @@ -100,23 +77,18 @@ def test_no_attrdict_middleware_does_not_convert_dicts_to_attrdict(): @pytest.mark.asyncio async def test_async_attrdict_middleware_default_for_async_ethereum_tester_provider(): async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) - assert async_w3.middleware_onion.get("attrdict") == async_attrdict_middleware + assert async_w3.middleware_onion.get("attrdict") == attrdict_middleware @pytest.mark.asyncio -async def test_async_attrdict_middleware_is_recursive(async_w3): - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - response = await async_w3.manager.coro_request("fake_endpoint", []) - - result = response["result"] +async def test_async_attrdict_middleware_is_recursive(async_w3, request_mocker): + async with request_mocker( + async_w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = await async_w3.manager.coro_request("fake_endpoint", []) + assert isinstance(result, AttributeDict) - assert response.result == result assert isinstance(result["b"], AttributeDict) assert result.b == result["b"] @@ -127,28 +99,21 @@ async def test_async_attrdict_middleware_is_recursive(async_w3): assert isinstance(result.b.b2.b2b["b2b2"], AttributeDict) assert result.b.b2.b2b.b2b2 == result.b.b2.b2b["b2b2"] - # cleanup - async_w3.middleware_onion.remove("result_gen") - @pytest.mark.asyncio -async def test_no_async_attrdict_middleware_does_not_convert_dicts_to_attrdict(): +async def test_no_async_attrdict_middleware_does_not_convert_dicts_to_attrdict( + request_mocker, +): async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - # remove attrdict middleware async_w3.middleware_onion.remove("attrdict") - response = await async_w3.manager.coro_request("fake_endpoint", []) - - result = response["result"] + async with request_mocker( + async_w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = await async_w3.manager.coro_request("fake_endpoint", []) _assert_dict_and_not_attrdict(result) _assert_dict_and_not_attrdict(result["b"]) diff --git a/tests/core/middleware/test_eth_tester_middleware.py b/tests/core/middleware/test_eth_tester_middleware.py index 6797a51fe6..7d8d085e71 100644 --- a/tests/core/middleware/test_eth_tester_middleware.py +++ b/tests/core/middleware/test_eth_tester_middleware.py @@ -4,7 +4,6 @@ ) from web3.providers.eth_tester.middleware import ( - async_default_transaction_fields_middleware, default_transaction_fields_middleware, ) from web3.types import ( @@ -93,9 +92,10 @@ def mock_request(_method, params): mock_w3.eth.accounts = w3_accounts mock_w3.eth.coinbase = w3_coinbase - middleware = default_transaction_fields_middleware(mock_request, mock_w3) + middleware = default_transaction_fields_middleware(mock_w3) base_params = {"chainId": 5} - filled_transaction = middleware(method, [base_params]) + inner = middleware._wrap_make_request(mock_request) + filled_transaction = inner(method, [base_params]) filled_params = filled_transaction[0] @@ -178,11 +178,10 @@ async def mock_async_coinbase(): mock_w3.eth.accounts = mock_async_accounts() mock_w3.eth.coinbase = mock_async_coinbase() - middleware = await async_default_transaction_fields_middleware( - mock_request, mock_w3 - ) + middleware = default_transaction_fields_middleware(mock_w3) base_params = {"chainId": 5} - filled_transaction = await middleware(method, [base_params]) + inner = await middleware._async_wrap_make_request(mock_request) + filled_transaction = await inner(method, [base_params]) filled_params = filled_transaction[0] assert ("from" in filled_params.keys()) == from_field_added diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index a65ea155f0..e24e249daf 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -6,20 +6,14 @@ import pytest_asyncio from web3 import ( + AsyncWeb3, Web3, ) from web3.datastructures import ( AttributeDict, ) -from web3.eth import ( - AsyncEth, -) from web3.middleware import ( - async_attrdict_middleware, - async_construct_result_generator_middleware, - async_local_filter_middleware, attrdict_middleware, - construct_result_generator_middleware, local_filter_middleware, ) from web3.middleware.filter import ( @@ -88,28 +82,20 @@ def iterator(): @pytest.fixture(scope="function") -def result_generator_middleware(iter_block_number): - return construct_result_generator_middleware( - { +def w3(request_mocker, iter_block_number): + w3_base = Web3(provider=DummyProvider(), middlewares=[]) + w3_base.middleware_onion.add(attrdict_middleware) + w3_base.middleware_onion.add(local_filter_middleware) + with request_mocker( + w3_base, + mock_results={ "eth_getLogs": lambda *_: FILTER_LOG, "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, "net_version": lambda *_: 1, "eth_blockNumber": lambda *_: next(iter_block_number), - } - ) - - -@pytest.fixture(scope="function") -def w3_base(): - return Web3(provider=DummyProvider(), middlewares=[]) - - -@pytest.fixture(scope="function") -def w3(w3_base, result_generator_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - w3_base.middleware_onion.add(attrdict_middleware) - w3_base.middleware_onion.add(local_filter_middleware) - return w3_base + }, + ): + yield w3_base @pytest.mark.parametrize( @@ -266,30 +252,21 @@ async def make_request(self, method, params): @pytest_asyncio.fixture(scope="function") -async def async_result_generator_middleware(iter_block_number): - return await async_construct_result_generator_middleware( - { +async def async_w3(request_mocker, iter_block_number): + async_w3_base = AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) + async_w3_base.middleware_onion.add(attrdict_middleware) + async_w3_base.middleware_onion.add(local_filter_middleware) + + async with request_mocker( + async_w3_base, + mock_results={ "eth_getLogs": lambda *_: FILTER_LOG, "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, "net_version": lambda *_: 1, "eth_blockNumber": lambda *_: next(iter_block_number), - } - ) - - -@pytest.fixture(scope="function") -def async_w3_base(): - return Web3( - provider=AsyncDummyProvider(), modules={"eth": (AsyncEth)}, middlewares=[] - ) - - -@pytest.fixture(scope="function") -def async_w3(async_w3_base, async_result_generator_middleware): - async_w3_base.middleware_onion.add(async_result_generator_middleware) - async_w3_base.middleware_onion.add(async_attrdict_middleware) - async_w3_base.middleware_onion.add(async_local_filter_middleware) - return async_w3_base + }, + ): + yield async_w3_base @pytest.mark.parametrize( diff --git a/tests/core/middleware/test_fixture_middleware.py b/tests/core/middleware/test_fixture_middleware.py deleted file mode 100644 index 0c73baa469..0000000000 --- a/tests/core/middleware/test_fixture_middleware.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest - -from web3 import ( - Web3, -) -from web3.middleware import ( - construct_error_generator_middleware, - construct_fixture_middleware, - construct_result_generator_middleware, -) -from web3.providers.base import ( - BaseProvider, -) - - -class DummyProvider(BaseProvider): - def make_request(self, method, params): - raise NotImplementedError(f"Cannot make request for {method}:{params}") - - -@pytest.fixture -def w3(): - return Web3(provider=DummyProvider(), middlewares=[]) - - -@pytest.mark.parametrize( - "method,expected", - ( - ("test_endpoint", "value-a"), - ("not_implemented", NotImplementedError), - ), -) -def test_fixture_middleware(w3, method, expected): - w3.middleware_onion.add(construct_fixture_middleware({"test_endpoint": "value-a"})) - - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - w3.manager.request_blocking(method, []) - else: - actual = w3.manager.request_blocking(method, []) - assert actual == expected - - -@pytest.mark.parametrize( - "method,expected", - ( - ("test_endpoint", "value-a"), - ("not_implemented", NotImplementedError), - ), -) -def test_result_middleware(w3, method, expected): - def _callback(method, params): - return params[0] - - w3.middleware_onion.add( - construct_result_generator_middleware( - { - "test_endpoint": _callback, - } - ) - ) - - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - w3.manager.request_blocking(method, [expected]) - else: - actual = w3.manager.request_blocking(method, [expected]) - assert actual == expected - - -@pytest.mark.parametrize( - "method,expected", - ( - ("test_endpoint", "value-a"), - ("not_implemented", NotImplementedError), - ), -) -def test_error_middleware(w3, method, expected): - def _callback(method, params): - return params[0] - - w3.middleware_onion.add( - construct_error_generator_middleware( - { - "test_endpoint": _callback, - } - ) - ) - - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - w3.manager.request_blocking(method, [expected]) - else: - with pytest.raises(ValueError) as err: - w3.manager.request_blocking(method, [expected]) - assert expected in str(err) diff --git a/tests/core/middleware/test_formatting_middleware.py b/tests/core/middleware/test_formatting_middleware.py index 8f7f9ca807..969f2bfb46 100644 --- a/tests/core/middleware/test_formatting_middleware.py +++ b/tests/core/middleware/test_formatting_middleware.py @@ -7,9 +7,7 @@ Web3, ) from web3.middleware import ( - construct_error_generator_middleware, construct_formatting_middleware, - construct_result_generator_middleware, ) from web3.providers.base import ( BaseProvider, @@ -29,20 +27,12 @@ def w3(): return Web3(provider=DummyProvider(), middlewares=[]) -def test_formatting_middleware(w3): +def test_formatting_middleware(w3, request_mocker): # No formatters by default - w3.middleware_onion.add(construct_formatting_middleware()) - w3.middleware_onion.add( - construct_result_generator_middleware( - { - "test_endpoint": lambda method, params: "done", - } - ) - ) - expected = "done" - actual = w3.manager.request_blocking("test_endpoint", []) - assert actual == expected + with request_mocker(w3, mock_results={"test_endpoint": "done"}): + actual = w3.manager.request_blocking(RPCEndpoint("test_endpoint"), []) + assert actual == expected def test_formatting_middleware_no_method(w3): @@ -53,74 +43,58 @@ def test_formatting_middleware_no_method(w3): w3.manager.request_blocking("test_endpoint", []) -def test_formatting_middleware_request_formatters(w3): +def test_formatting_middleware_request_formatters(w3, request_mocker): callable_mock = Mock() - w3.middleware_onion.add( - construct_result_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: "done"} - ) - ) - w3.middleware_onion.add( construct_formatting_middleware( - request_formatters={RPCEndpoint("test_endpoint"): callable_mock} + request_formatters={"test_endpoint": callable_mock} ) ) expected = "done" - actual = w3.manager.request_blocking("test_endpoint", ["param1"]) + with request_mocker(w3, mock_results={"test_endpoint": "done"}): + actual = w3.manager.request_blocking("test_endpoint", ["param1"]) callable_mock.assert_called_once_with(["param1"]) assert actual == expected -def test_formatting_middleware_result_formatters(w3): - w3.middleware_onion.add( - construct_result_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: "done"} - ) - ) +def test_formatting_middleware_result_formatters(w3, request_mocker): w3.middleware_onion.add( construct_formatting_middleware( - result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} + result_formatters={"test_endpoint": lambda x: f"STATUS: {x}"} ) ) - expected = "STATUS:done" - actual = w3.manager.request_blocking("test_endpoint", []) + expected = "STATUS: done" + with request_mocker(w3, mock_results={"test_endpoint": "done"}): + actual = w3.manager.request_blocking("test_endpoint", []) + assert actual == expected -def test_formatting_middleware_result_formatters_for_none(w3): - w3.middleware_onion.add( - construct_result_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: None} - ) - ) +def test_formatting_middleware_result_formatters_for_none(w3, request_mocker): w3.middleware_onion.add( construct_formatting_middleware( - result_formatters={RPCEndpoint("test_endpoint"): lambda x: hex(x)} + result_formatters={"test_endpoint": lambda x: hex(x)} ) ) expected = None - actual = w3.manager.request_blocking("test_endpoint", []) + with request_mocker(w3, mock_results={"test_endpoint": expected}): + actual = w3.manager.request_blocking("test_endpoint", []) assert actual == expected -def test_formatting_middleware_error_formatters(w3): - w3.middleware_onion.add( - construct_error_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: "error"} - ) - ) +def test_formatting_middleware_error_formatters(w3, request_mocker): w3.middleware_onion.add( construct_formatting_middleware( - result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} + result_formatters={"test_endpoint": lambda x: f"STATUS: {x}"} ) ) expected = "error" - with pytest.raises(ValueError) as err: - w3.manager.request_blocking("test_endpoint", []) - assert str(err.value) == expected + with request_mocker(w3, mock_errors={"test_endpoint": {"message": "error"}}): + with pytest.raises(ValueError) as err: + w3.manager.request_blocking("test_endpoint", []) + assert str(err.value) == expected diff --git a/tests/core/middleware/test_gas_price_strategy.py b/tests/core/middleware/test_gas_price_strategy.py index c5193d41c2..8f62e4a270 100644 --- a/tests/core/middleware/test_gas_price_strategy.py +++ b/tests/core/middleware/test_gas_price_strategy.py @@ -3,91 +3,67 @@ Mock, ) +from toolz import ( + merge, +) + from web3.middleware import ( gas_price_strategy_middleware, ) @pytest.fixture -def the_gas_price_strategy_middleware(w3): - make_request, w3 = Mock(), Mock() - initialized = gas_price_strategy_middleware(make_request, w3) - initialized.w3 = w3 - initialized.make_request = make_request +def the_gas_price_strategy_middleware(): + w3 = Mock() + initialized = gas_price_strategy_middleware(w3) return initialized def test_gas_price_generated(the_gas_price_strategy_middleware): - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.return_value = 5 - method = "eth_sendTransaction" - params = ( - { - "to": "0x0", - "value": 1, - }, - ) - the_gas_price_strategy_middleware(method, params) - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.assert_called_once_with( - { - "to": "0x0", - "value": 1, - } - ) - the_gas_price_strategy_middleware.make_request.assert_called_once_with( - method, - ( - { - "to": "0x0", - "value": 1, - "gasPrice": "0x5", - }, - ), + w3 = the_gas_price_strategy_middleware._w3 + w3.eth.generate_gas_price.return_value = 5 + + make_request = Mock() + inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) + method, dict_param = "eth_sendTransaction", {"to": "0x0", "value": 1} + inner(method, (dict_param,)) + + w3.eth.generate_gas_price.assert_called_once_with(dict_param) + make_request.assert_called_once_with( + method, (merge(dict_param, {"gasPrice": "0x5"}),) ) def test_gas_price_not_overridden(the_gas_price_strategy_middleware): - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.return_value = 5 - method = "eth_sendTransaction" - params = ( - { - "to": "0x0", - "value": 1, - "gasPrice": 10, - }, - ) - the_gas_price_strategy_middleware(method, params) - the_gas_price_strategy_middleware.make_request.assert_called_once_with( - method, - ( - { - "to": "0x0", - "value": 1, - "gasPrice": 10, - }, - ), - ) + the_gas_price_strategy_middleware._w3.eth.generate_gas_price.return_value = 5 + + make_request = Mock() + inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) + method, params = "eth_sendTransaction", ({"to": "0x0", "value": 1, "gasPrice": 10},) + inner(method, params) + + make_request.assert_called_once_with(method, params) def test_gas_price_not_set_without_gas_price_strategy( the_gas_price_strategy_middleware, ): - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.return_value = None - method = "eth_sendTransaction" - params = ( - { - "to": "0x0", - "value": 1, - }, - ) - the_gas_price_strategy_middleware(method, params) - the_gas_price_strategy_middleware.make_request.assert_called_once_with( - method, params - ) + the_gas_price_strategy_middleware._w3.eth.generate_gas_price.return_value = None + + make_request = Mock() + inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) + method, params = "eth_sendTransaction", ({"to": "0x0", "value": 1},) + inner(method, params) + + make_request.assert_called_once_with(method, params) def test_not_generate_gas_price_when_not_send_transaction_rpc( the_gas_price_strategy_middleware, ): - the_gas_price_strategy_middleware.w3.getGasPriceStrategy = Mock() - the_gas_price_strategy_middleware("eth_getBalance", []) - the_gas_price_strategy_middleware.w3.getGasPriceStrategy.assert_not_called() + the_gas_price_strategy_middleware._w3.get_gas_price_strategy = Mock() + + inner = the_gas_price_strategy_middleware._wrap_make_request(Mock()) + inner("eth_getBalance", []) + + the_gas_price_strategy_middleware._w3.get_gas_price_strategy.assert_not_called() diff --git a/tests/core/middleware/test_http_request_retry.py b/tests/core/middleware/test_http_request_retry.py deleted file mode 100644 index 02060b5d85..0000000000 --- a/tests/core/middleware/test_http_request_retry.py +++ /dev/null @@ -1,208 +0,0 @@ -import pytest -from unittest.mock import ( - Mock, - patch, -) - -import aiohttp -import pytest_asyncio -from requests.exceptions import ( - ConnectionError, - HTTPError, - Timeout, - TooManyRedirects, -) - -import web3 -from web3 import ( - AsyncHTTPProvider, - AsyncWeb3, -) -from web3.middleware.exception_retry_request import ( - async_exception_retry_middleware, - check_if_retry_on_failure, - exception_retry_middleware, -) -from web3.providers import ( - HTTPProvider, - IPCProvider, -) - - -@pytest.fixture -def exception_retry_request_setup(): - w3 = Mock() - provider = HTTPProvider() - errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) - setup = exception_retry_middleware(provider.make_request, w3, errors, 5) - setup.w3 = w3 - return setup - - -def test_check_if_retry_on_failure_false(): - methods = [ - "eth_sendTransaction", - "personal_signAndSendTransaction", - "personal_sendTransaction", - ] - - for method in methods: - assert not check_if_retry_on_failure(method) - - -def test_check_if_retry_on_failure_true(): - method = "eth_getBalance" - assert check_if_retry_on_failure(method) - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_check_send_transaction_called_once( - make_post_request_mock, exception_retry_request_setup -): - method = "eth_sendTransaction" - params = [ - { - "to": "0x0", - "value": 1, - } - ] - - with pytest.raises(ConnectionError): - exception_retry_request_setup(method, params) - assert make_post_request_mock.call_count == 1 - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_valid_method_retried(make_post_request_mock, exception_retry_request_setup): - method = "eth_getBalance" - params = [] - - with pytest.raises(ConnectionError): - exception_retry_request_setup(method, params) - assert make_post_request_mock.call_count == 5 - - -def test_is_strictly_default_http_middleware(): - w3 = HTTPProvider() - assert "http_retry_request" in w3.middlewares - - w3 = IPCProvider() - assert "http_retry_request" not in w3.middlewares - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_check_with_all_middlewares(make_post_request_mock): - provider = HTTPProvider() - w3 = web3.Web3(provider) - with pytest.raises(ConnectionError): - w3.eth.block_number - assert make_post_request_mock.call_count == 5 - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_exception_retry_middleware_with_allow_list_kwarg( - make_post_request_mock, exception_retry_request_setup -): - w3 = Mock() - provider = HTTPProvider() - errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) - setup = exception_retry_middleware( - provider.make_request, w3, errors, 5, allow_list=["test_userProvidedMethod"] - ) - setup.w3 = w3 - with pytest.raises(ConnectionError): - setup("test_userProvidedMethod", []) - assert make_post_request_mock.call_count == 5 - - make_post_request_mock.reset_mock() - with pytest.raises(ConnectionError): - setup("eth_getBalance", []) - assert make_post_request_mock.call_count == 1 - - -# -- async -- # - - -ASYNC_TEST_RETRY_COUNT = 3 - - -@pytest_asyncio.fixture -async def async_exception_retry_request_setup(): - w3 = Mock() - provider = AsyncHTTPProvider() - setup = await async_exception_retry_middleware( - provider.make_request, - w3, - (TimeoutError, aiohttp.ClientError), - ASYNC_TEST_RETRY_COUNT, - 0.1, - ) - setup.w3 = w3 - return setup - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "error", - ( - TimeoutError, - aiohttp.ClientError, # general base class for all aiohttp errors - # ClientError subclasses - aiohttp.ClientConnectionError, - aiohttp.ServerTimeoutError, - aiohttp.ClientOSError, - ), -) -async def test_async_check_retry_middleware(error, async_exception_retry_request_setup): - with patch( - "web3.providers.async_rpc.async_make_post_request" - ) as async_make_post_request_mock: - async_make_post_request_mock.side_effect = error - - with pytest.raises(error): - await async_exception_retry_request_setup("eth_getBalance", []) - assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT - - -@pytest.mark.asyncio -async def test_async_check_without_retry_middleware(): - with patch( - "web3.providers.async_rpc.async_make_post_request" - ) as async_make_post_request_mock: - async_make_post_request_mock.side_effect = TimeoutError - provider = AsyncHTTPProvider() - w3 = AsyncWeb3(provider) - w3.provider._middlewares = () - - with pytest.raises(TimeoutError): - await w3.eth.block_number - assert async_make_post_request_mock.call_count == 1 - - -@pytest.mark.asyncio -async def test_async_exception_retry_middleware_with_allow_list_kwarg(): - w3 = Mock() - provider = AsyncHTTPProvider() - setup = await async_exception_retry_middleware( - provider.make_request, - w3, - (TimeoutError, aiohttp.ClientError), - retries=ASYNC_TEST_RETRY_COUNT, - backoff_factor=0.1, - allow_list=["test_userProvidedMethod"], - ) - setup.w3 = w3 - - with patch( - "web3.providers.async_rpc.async_make_post_request" - ) as async_make_post_request_mock: - async_make_post_request_mock.side_effect = TimeoutError - - with pytest.raises(TimeoutError): - await setup("test_userProvidedMethod", []) - assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT - - async_make_post_request_mock.reset_mock() - with pytest.raises(TimeoutError): - await setup("eth_getBalance", []) - assert async_make_post_request_mock.call_count == 1 diff --git a/tests/core/middleware/test_latest_block_based_cache_middleware.py b/tests/core/middleware/test_latest_block_based_cache_middleware.py deleted file mode 100644 index 32f88333af..0000000000 --- a/tests/core/middleware/test_latest_block_based_cache_middleware.py +++ /dev/null @@ -1,287 +0,0 @@ -import codecs -import itertools -import pytest -import time -import uuid - -from eth_utils import ( - is_hex, - is_integer, - to_tuple, -) - -from web3 import ( - Web3, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3._utils.formatters import ( - hex_to_integer, -) -from web3.middleware import ( - construct_error_generator_middleware, - construct_latest_block_based_cache_middleware, - construct_result_generator_middleware, -) -from web3.providers.base import ( - BaseProvider, -) - - -@pytest.fixture -def w3_base(): - return Web3(provider=BaseProvider(), middlewares=[]) - - -def _mk_block(n, timestamp): - return { - "hash": codecs.decode(str(n).zfill(32), "hex"), - "number": n, - "timestamp": timestamp, - } - - -@to_tuple -def generate_block_history(num_mined_blocks=5, block_time=1): - genesis = _mk_block(0, time.time()) - yield genesis - for block_number in range(1, num_mined_blocks + 1): - yield _mk_block( - block_number, - genesis["timestamp"] + 2 * block_number, - ) - - -@pytest.fixture -def construct_block_data_middleware(): - def _construct_block_data_middleware(num_blocks): - blocks = generate_block_history(num_blocks) - _block_info = {"blocks": blocks, "head_block_number": blocks[0]["number"]} - - def _evm_mine(method, params, block_info=_block_info): - num_blocks = params[0] - head_block_number = block_info["head_block_number"] - if head_block_number + num_blocks >= len(block_info["blocks"]): - raise ValueError("no more blocks to mine") - - block_info["head_block_number"] += num_blocks - - def _get_block_by_number(method, params, block_info=_block_info): - block_id = params[0] - blocks = block_info["blocks"] - head_block_number = block_info["head_block_number"] - - if block_id == "latest": - return blocks[head_block_number] - elif block_id == "pending": - if head_block_number + 1 >= len(blocks): - raise ValueError("no pending block") - return blocks[head_block_number + 1] - elif block_id == "earliest": - return blocks[0] - elif is_integer(block_id): - if block_id <= head_block_number: - return blocks[block_id] - else: - return None - elif is_hex(block_id): - block_id = hex_to_integer(block_id) - if block_id <= head_block_number: - return blocks[block_id] - else: - return None - else: - raise TypeError("Invalid type for block_id") - - def _get_block_by_hash(method, params, block_info=_block_info): - block_hash = params[0] - blocks = block_info["blocks"] - head_block_number = block_info["head_block_number"] - - blocks_by_hash = {block["hash"]: block for block in blocks} - try: - block = blocks_by_hash[block_hash] - if block["number"] <= head_block_number: - return block - else: - return None - except KeyError: - return None - - return construct_result_generator_middleware( - { - "eth_getBlockByNumber": _get_block_by_number, - "eth_getBlockByHash": _get_block_by_hash, - "evm_mine": _evm_mine, - } - ) - - return _construct_block_data_middleware - - -@pytest.fixture -def block_data_middleware(construct_block_data_middleware): - return construct_block_data_middleware(5) - - -@pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { - "fake_endpoint": lambda *_: str(uuid.uuid4()), - "not_whitelisted": lambda *_: str(uuid.uuid4()), - } - ) - - -@pytest.fixture -def latest_block_based_cache_middleware(): - return construct_latest_block_based_cache_middleware( - cache_class=dict, - average_block_time_sample_size=1, - default_average_block_time=0.1, - rpc_whitelist={"fake_endpoint"}, - ) - - -@pytest.fixture -def w3( - w3_base, - result_generator_middleware, - block_data_middleware, - latest_block_based_cache_middleware, -): - w3_base.middleware_onion.add(block_data_middleware) - w3_base.middleware_onion.add(result_generator_middleware) - w3_base.middleware_onion.add(latest_block_based_cache_middleware) - return w3_base - - -def test_latest_block_based_cache_middleware_pulls_from_cache( - w3_base, block_data_middleware, result_generator_middleware -): - w3 = w3_base - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add(result_generator_middleware) - - current_block_hash = w3.eth.get_block("latest")["hash"] - - def cache_class(): - return { - generate_cache_key((current_block_hash, "fake_endpoint", [1])): { - "result": "value-a" - }, - } - - w3.middleware_onion.add( - construct_latest_block_based_cache_middleware( - cache_class=cache_class, - rpc_whitelist={"fake_endpoint"}, - ) - ) - - assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" - - -def test_latest_block_based_cache_middleware_populates_cache(w3): - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - assert w3.manager.request_blocking("fake_endpoint", [1]) != result - - -def test_latest_block_based_cache_middleware_busts_cache(w3, mocker): - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - w3.testing.mine() - - # should still be cached for at least 1 second. This also verifies that - # the middleware caches the latest block based on the block time. - assert w3.manager.request_blocking("fake_endpoint", []) == result - - mocker.patch("time.time", return_value=time.time() + 5) - - assert w3.manager.request_blocking("fake_endpoint", []) != result - - -def test_latest_block_cache_middleware_does_not_cache_bad_responses( - w3_base, block_data_middleware, latest_block_based_cache_middleware -): - counter = itertools.count() - w3 = w3_base - - def result_cb(method, params): - next(counter) - return None - - w3 = w3_base - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add( - construct_result_generator_middleware( - { - "fake_endpoint": result_cb, - } - ) - ) - w3.middleware_onion.add(latest_block_based_cache_middleware) - - w3.manager.request_blocking("fake_endpoint", []) - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_latest_block_cache_middleware_does_not_cache_error_response( - w3_base, block_data_middleware, latest_block_based_cache_middleware -): - counter = itertools.count() - w3 = w3_base - - def error_cb(method, params): - next(counter) - return "the error message" - - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add( - construct_error_generator_middleware( - { - "fake_endpoint": error_cb, - } - ) - ) - w3.middleware_onion.add(latest_block_based_cache_middleware) - - with pytest.raises(ValueError): - w3.manager.request_blocking("fake_endpoint", []) - with pytest.raises(ValueError): - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_latest_block_cache_middleware_does_not_cache_get_latest_block( - w3_base, block_data_middleware, result_generator_middleware -): - w3 = w3_base - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add(result_generator_middleware) - - current_block_hash = w3.eth.get_block("latest")["hash"] - - def cache_class(): - return { - generate_cache_key( - (current_block_hash, "eth_getBlockByNumber", ["latest"]) - ): {"result": "value-a"}, - } - - w3.middleware_onion.add( - construct_latest_block_based_cache_middleware( - cache_class=cache_class, - rpc_whitelist={"eth_getBlockByNumber"}, - ) - ) - - assert w3.manager.request_blocking("eth_getBlockByNumber", ["latest"]) != "value-a" diff --git a/tests/core/middleware/test_name_to_address_middleware.py b/tests/core/middleware/test_name_to_address_middleware.py index d6dc040942..7eee54568d 100644 --- a/tests/core/middleware/test_name_to_address_middleware.py +++ b/tests/core/middleware/test_name_to_address_middleware.py @@ -11,10 +11,7 @@ NameNotFound, ) from web3.middleware import ( - name_to_address_middleware, -) -from web3.middleware.names import ( - async_name_to_address_middleware, + ens_name_to_address_middleware, ) from web3.providers.eth_tester import ( AsyncEthereumTesterProvider, @@ -50,7 +47,8 @@ def ens_addr_account_balance(ens_mapped_address, _w3_setup): @pytest.fixture def w3(_w3_setup, ens_mapped_address): _w3_setup.ens = TempENS({NAME: ens_mapped_address}) - _w3_setup.middleware_onion.add(name_to_address_middleware(_w3_setup)) + ens_name_to_address_middleware._w3 = _w3_setup + _w3_setup.middleware_onion.add(ens_name_to_address_middleware) return _w3_setup @@ -128,7 +126,7 @@ async def async_ens_addr_account_balance(async_ens_mapped_address, _async_w3_set @pytest_asyncio.fixture async def async_w3(_async_w3_setup, async_ens_mapped_address): _async_w3_setup.ens = AsyncTempENS({NAME: async_ens_mapped_address}) - _async_w3_setup.middleware_onion.add(async_name_to_address_middleware) + _async_w3_setup.middleware_onion.add(ens_name_to_address_middleware) return _async_w3_setup diff --git a/tests/core/middleware/test_simple_cache_middleware.py b/tests/core/middleware/test_simple_cache_middleware.py deleted file mode 100644 index 5a89e01496..0000000000 --- a/tests/core/middleware/test_simple_cache_middleware.py +++ /dev/null @@ -1,352 +0,0 @@ -import itertools -import pytest -import threading -import uuid - -from web3 import ( - AsyncWeb3, - Web3, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3.middleware import ( - async_simple_cache_middleware, - construct_error_generator_middleware, - construct_result_generator_middleware, - construct_simple_cache_middleware, - simple_cache_middleware, -) -from web3.middleware.async_cache import ( - async_construct_simple_cache_middleware, -) -from web3.middleware.fixture import ( - async_construct_error_generator_middleware, - async_construct_result_generator_middleware, -) -from web3.providers.base import ( - BaseProvider, -) -from web3.providers.eth_tester import ( - AsyncEthereumTesterProvider, -) -from web3.types import ( - RPCEndpoint, -) -from web3.utils.caching import ( - SimpleCache, -) - - -@pytest.fixture -def w3_base(): - return Web3(provider=BaseProvider(), middlewares=[]) - - -@pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()), - RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()), - } - ) - - -@pytest.fixture -def w3(w3_base, result_generator_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - return w3_base - - -def simple_cache_return_value_a(): - _cache = SimpleCache() - _cache.cache( - generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"), - {"result": "value-a"}, - ) - return _cache - - -def test_simple_cache_middleware_pulls_from_cache(w3): - w3.middleware_onion.add( - construct_simple_cache_middleware( - cache=simple_cache_return_value_a(), - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" - - -def test_simple_cache_middleware_populates_cache(w3): - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - assert w3.manager.request_blocking("fake_endpoint", [1]) != result - - -def test_simple_cache_middleware_does_not_cache_none_responses(w3_base): - counter = itertools.count() - w3 = w3_base - - def result_cb(_method, _params): - next(counter) - return None - - w3.middleware_onion.add( - construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): result_cb, - } - ) - ) - - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - w3.manager.request_blocking("fake_endpoint", []) - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_simple_cache_middleware_does_not_cache_error_responses(w3_base): - w3 = w3_base - w3.middleware_onion.add( - construct_error_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}", - } - ) - ) - - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - with pytest.raises(ValueError) as err_a: - w3.manager.request_blocking("fake_endpoint", []) - with pytest.raises(ValueError) as err_b: - w3.manager.request_blocking("fake_endpoint", []) - - assert str(err_a) != str(err_b) - - -def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - result_a = w3.manager.request_blocking("not_whitelisted", []) - result_b = w3.manager.request_blocking("not_whitelisted", []) - - assert result_a != result_b - - -def test_simple_cache_middleware_does_not_share_state_between_providers(): - result_generator_a = construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 11111} - ) - result_generator_b = construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 22222} - ) - result_generator_c = construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 33333} - ) - - w3_a = Web3(provider=BaseProvider(), middlewares=[result_generator_a]) - w3_b = Web3(provider=BaseProvider(), middlewares=[result_generator_b]) - w3_c = Web3( # instantiate the Web3 instance with the cache middleware - provider=BaseProvider(), - middlewares=[ - result_generator_c, - simple_cache_middleware, - ], - ) - - w3_a.middleware_onion.add(simple_cache_middleware) - w3_b.middleware_onion.add(simple_cache_middleware) - - result_a = w3_a.manager.request_blocking("eth_chainId", []) - result_b = w3_b.manager.request_blocking("eth_chainId", []) - result_c = w3_c.manager.request_blocking("eth_chainId", []) - - assert result_a != result_b != result_c - assert result_a == 11111 - assert result_b == 22222 - assert result_c == 33333 - - -# -- async -- # - - -async def _async_simple_cache_middleware_for_testing(make_request, async_w3): - middleware = await async_construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - return await middleware(make_request, async_w3) - - -@pytest.fixture -def async_w3(): - return AsyncWeb3( - provider=AsyncEthereumTesterProvider(), - middlewares=[ - (_async_simple_cache_middleware_for_testing, "simple_cache"), - ], - ) - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_pulls_from_cache(async_w3): - async def _properly_awaited_middleware(make_request, _async_w3): - middleware = await async_construct_simple_cache_middleware( - cache=simple_cache_return_value_a(), - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - return await middleware(make_request, _async_w3) - - async_w3.middleware_onion.inject( - _properly_awaited_middleware, - layer=0, - ) - - _result = await async_w3.manager.coro_request("fake_endpoint", [1]) - assert _result == "value-a" - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_populates_cache(async_w3): - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()), - } - ), - "result_generator", - layer=0, - ) - - result = await async_w3.manager.coro_request("fake_endpoint", []) - - _empty_params = await async_w3.manager.coro_request("fake_endpoint", []) - _non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1]) - - assert _empty_params == result - assert _non_empty_params != result - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_cache_none_responses(async_w3): - counter = itertools.count() - - def result_cb(_method, _params): - next(counter) - return None - - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): result_cb, - }, - ), - "result_generator", - layer=0, - ) - - await async_w3.manager.coro_request("fake_endpoint", []) - await async_w3.manager.coro_request("fake_endpoint", []) - - assert next(counter) == 2 - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_cache_error_responses(async_w3): - async_w3.middleware_onion.inject( - await async_construct_error_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}", - } - ), - "error_generator", - layer=0, - ) - - with pytest.raises(ValueError) as err_a: - await async_w3.manager.coro_request("fake_endpoint", []) - with pytest.raises(ValueError) as err_b: - await async_w3.manager.coro_request("fake_endpoint", []) - - assert str(err_a) != str(err_b) - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_cache_non_whitelist_endpoints( - async_w3, -): - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - { - RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()), - } - ), - layer=0, - ) - - result_a = await async_w3.manager.coro_request("not_whitelisted", []) - result_b = await async_w3.manager.coro_request("not_whitelisted", []) - - assert result_a != result_b - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_share_state_between_providers(): - result_generator_a = await async_construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 11111} - ) - result_generator_b = await async_construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 22222} - ) - result_generator_c = await async_construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 33333} - ) - - w3_a = AsyncWeb3( - provider=AsyncEthereumTesterProvider(), middlewares=[result_generator_a] - ) - w3_b = AsyncWeb3( - provider=AsyncEthereumTesterProvider(), middlewares=[result_generator_b] - ) - w3_c = AsyncWeb3( # instantiate the Web3 instance with the cache middleware - provider=AsyncEthereumTesterProvider(), - middlewares=[ - result_generator_c, - async_simple_cache_middleware, - ], - ) - - w3_a.middleware_onion.add(async_simple_cache_middleware) - w3_b.middleware_onion.add(async_simple_cache_middleware) - - result_a = await w3_a.eth.chain_id - result_b = await w3_b.eth.chain_id - result_c = await w3_c.eth.chain_id - - assert result_a != result_b != result_c - assert result_a == 11111 - assert result_b == 22222 - assert result_c == 33333 diff --git a/tests/core/middleware/test_stalecheck.py b/tests/core/middleware/test_stalecheck.py index a558b74e95..448332d266 100644 --- a/tests/core/middleware/test_stalecheck.py +++ b/tests/core/middleware/test_stalecheck.py @@ -16,7 +16,6 @@ from web3.middleware.stalecheck import ( StaleBlockchain, _is_fresh, - async_make_stalecheck_middleware, ) @@ -32,13 +31,10 @@ def allowable_delay(): @pytest.fixture def request_middleware(allowable_delay): - middleware = make_stalecheck_middleware(allowable_delay) - make_request, web3 = Mock(), Mock() - initialized = middleware(make_request, web3) - # for easier mocking, later: - initialized.web3 = web3 - initialized.make_request = make_request - return initialized + web3 = Mock() + middleware = make_stalecheck_middleware(allowable_delay, web3) + middleware._w3.provider.make_request = Mock() + return middleware def stub_block(timestamp): @@ -70,16 +66,23 @@ def test_is_fresh(now): def test_stalecheck_pass(request_middleware): with patch("web3.middleware.stalecheck._is_fresh", return_value=True): + make_request = Mock() + inner = request_middleware._wrap_make_request(make_request) + method, params = object(), object() - request_middleware(method, params) - request_middleware.make_request.assert_called_once_with(method, params) + inner(method, params) + + make_request.assert_called_once_with(method, params) def test_stalecheck_fail(request_middleware, now): with patch("web3.middleware.stalecheck._is_fresh", return_value=False): - request_middleware.web3.eth.get_block.return_value = stub_block(now) + request_middleware._w3.eth.get_block.return_value = stub_block(now) + + response = object() + inner = request_middleware._wrap_make_request(lambda *_: response) with pytest.raises(StaleBlockchain): - request_middleware("", []) + inner("", []) @pytest.mark.parametrize( @@ -92,8 +95,9 @@ def test_stalecheck_ignores_get_by_block_methods(request_middleware, rpc_method) # This is especially critical for get_block('latest') # which would cause infinite recursion with patch("web3.middleware.stalecheck._is_fresh", side_effect=[False, True]): - request_middleware(rpc_method, []) - assert not request_middleware.web3.eth.get_block.called + inner = request_middleware._wrap_make_request(lambda *_: None) + inner(rpc_method, []) + assert not request_middleware._w3.eth.get_block.called def test_stalecheck_calls_is_fresh_with_empty_cache( @@ -103,8 +107,9 @@ def test_stalecheck_calls_is_fresh_with_empty_cache( "web3.middleware.stalecheck._is_fresh", side_effect=[False, True] ) as fresh_spy: block = object() - request_middleware.web3.eth.get_block.return_value = block - request_middleware("", []) + request_middleware._w3.eth.get_block.return_value = block + inner = request_middleware._wrap_make_request(lambda *_: None) + inner("", []) cache_call, live_call = fresh_spy.call_args_list assert cache_call[0] == (None, allowable_delay) assert live_call[0] == (block, allowable_delay) @@ -115,17 +120,18 @@ def test_stalecheck_adds_block_to_cache(request_middleware, allowable_delay): "web3.middleware.stalecheck._is_fresh", side_effect=[False, True, True] ) as fresh_spy: block = object() - request_middleware.web3.eth.get_block.return_value = block + request_middleware._w3.eth.get_block.return_value = block # cache miss - request_middleware("", []) + inner = request_middleware._wrap_make_request(lambda *_: None) + inner("", []) cache_call, live_call = fresh_spy.call_args_list assert fresh_spy.call_count == 2 assert cache_call == ((None, allowable_delay),) assert live_call == ((block, allowable_delay),) # cache hit - request_middleware("", []) + inner("", []) assert fresh_spy.call_count == 3 assert fresh_spy.call_args == ((block, allowable_delay),) @@ -138,63 +144,76 @@ def test_stalecheck_adds_block_to_cache(request_middleware, allowable_delay): ) +async def _coro(_method, _params): + return None + + @pytest_asyncio.fixture -async def request_async_middleware(allowable_delay): +async def async_request_middleware(allowable_delay): from unittest.mock import ( AsyncMock, ) - middleware = await async_make_stalecheck_middleware(allowable_delay) - make_request, web3 = AsyncMock(), AsyncMock() - initialized = await middleware(make_request, web3) - # for easier mocking, later: - initialized.web3 = web3 - initialized.make_request = make_request - return initialized + async_web3 = AsyncMock() + middleware = make_stalecheck_middleware(allowable_delay, async_web3) + middleware._w3.provider.make_request = Mock() + return middleware @pytest.mark.asyncio @min_version -async def test_async_stalecheck_pass(request_async_middleware): +async def test_async_stalecheck_pass(async_request_middleware): + from unittest.mock import ( + AsyncMock, + ) + with patch("web3.middleware.stalecheck._is_fresh", return_value=True): + make_request = AsyncMock() + inner = await async_request_middleware._async_wrap_make_request(make_request) + method, params = object(), object() - await request_async_middleware(method, params) - request_async_middleware.make_request.assert_called_once_with(method, params) + await inner(method, params) + + make_request.assert_called_once_with(method, params) @pytest.mark.asyncio @min_version -async def test_async_stalecheck_fail(request_async_middleware, now): +async def test_async_stalecheck_fail(async_request_middleware, now): with patch("web3.middleware.stalecheck._is_fresh", return_value=False): - request_async_middleware.web3.eth.get_block.return_value = stub_block(now) + async_request_middleware._w3.eth.get_block.return_value = stub_block(now) + with pytest.raises(StaleBlockchain): - await request_async_middleware("", []) + inner = await async_request_middleware._async_wrap_make_request(_coro) + await inner("", []) @pytest.mark.asyncio @pytest.mark.parametrize("rpc_method", ["eth_getBlockByNumber"]) @min_version async def test_async_stalecheck_ignores_get_by_block_methods( - request_async_middleware, rpc_method + async_request_middleware, rpc_method ): # This is especially critical for get_block("latest") which would cause # infinite recursion with patch("web3.middleware.stalecheck._is_fresh", side_effect=[False, True]): - await request_async_middleware(rpc_method, []) - assert not request_async_middleware.web3.eth.get_block.called + inner = await async_request_middleware._async_wrap_make_request(_coro) + await inner(rpc_method, []) + assert not async_request_middleware._w3.eth.get_block.called @pytest.mark.asyncio @min_version async def test_async_stalecheck_calls_is_fresh_with_empty_cache( - request_async_middleware, allowable_delay + async_request_middleware, allowable_delay ): with patch( "web3.middleware.stalecheck._is_fresh", side_effect=[False, True] ) as fresh_spy: block = object() - request_async_middleware.web3.eth.get_block.return_value = block - await request_async_middleware("", []) + async_request_middleware._w3.eth.get_block.return_value = block + inner = await async_request_middleware._async_wrap_make_request(_coro) + await inner("", []) cache_call, live_call = fresh_spy.call_args_list assert cache_call[0] == (None, allowable_delay) assert live_call[0] == (block, allowable_delay) @@ -203,22 +222,24 @@ async def test_async_stalecheck_calls_is_fresh_with_empty_cache( @pytest.mark.asyncio @min_version async def test_async_stalecheck_adds_block_to_cache( - request_async_middleware, allowable_delay + async_request_middleware, allowable_delay ): with patch( "web3.middleware.stalecheck._is_fresh", side_effect=[False, True, True] ) as fresh_spy: block = object() - request_async_middleware.web3.eth.get_block.return_value = block + async_request_middleware._w3.eth.get_block.return_value = block + + inner = await async_request_middleware._async_wrap_make_request(_coro) # cache miss - await request_async_middleware("", []) + await inner("", []) cache_call, live_call = fresh_spy.call_args_list assert fresh_spy.call_count == 2 assert cache_call == ((None, allowable_delay),) assert live_call == ((block, allowable_delay),) # cache hit - await request_async_middleware("", []) + await inner("", []) assert fresh_spy.call_count == 3 assert fresh_spy.call_args == ((block, allowable_delay),) diff --git a/tests/core/middleware/test_time_based_cache_middleware.py b/tests/core/middleware/test_time_based_cache_middleware.py deleted file mode 100644 index fa861b9948..0000000000 --- a/tests/core/middleware/test_time_based_cache_middleware.py +++ /dev/null @@ -1,168 +0,0 @@ -import itertools -import pytest -import time -import uuid - -from web3 import ( - Web3, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3.middleware import ( - construct_error_generator_middleware, - construct_result_generator_middleware, - construct_time_based_cache_middleware, -) -from web3.providers.base import ( - BaseProvider, -) - - -@pytest.fixture -def w3_base(): - return Web3(provider=BaseProvider(), middlewares=[]) - - -@pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { - "fake_endpoint": lambda *_: str(uuid.uuid4()), - "not_whitelisted": lambda *_: str(uuid.uuid4()), - } - ) - - -@pytest.fixture -def time_cache_middleware(): - return construct_time_based_cache_middleware( - cache_class=dict, - cache_expire_seconds=10, - rpc_whitelist={"fake_endpoint"}, - ) - - -@pytest.fixture -def w3(w3_base, result_generator_middleware, time_cache_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - w3_base.middleware_onion.add(time_cache_middleware) - return w3_base - - -def test_time_based_cache_middleware_pulls_from_cache(w3_base): - w3 = w3_base - - def cache_class(): - return { - generate_cache_key(("fake_endpoint", [1])): ( - time.time(), - {"result": "value-a"}, - ), - } - - w3.middleware_onion.add( - construct_time_based_cache_middleware( - cache_class=cache_class, - cache_expire_seconds=10, - rpc_whitelist={"fake_endpoint"}, - ) - ) - - assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" - - -def test_time_based_cache_middleware_populates_cache(w3): - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - assert w3.manager.request_blocking("fake_endpoint", [1]) != result - - -def test_time_based_cache_middleware_expires_old_values( - w3_base, result_generator_middleware -): - w3 = w3_base - w3.middleware_onion.add(result_generator_middleware) - - def cache_class(): - return { - generate_cache_key(("fake_endpoint", [1])): ( - time.time() - 10, - {"result": "value-a"}, - ), - } - - w3.middleware_onion.add( - construct_time_based_cache_middleware( - cache_class=cache_class, - cache_expire_seconds=10, - rpc_whitelist={"fake_endpoint"}, - ) - ) - - result = w3.manager.request_blocking("fake_endpoint", [1]) - assert result != "value-a" - assert w3.manager.request_blocking("fake_endpoint", [1]) == result - - -@pytest.mark.parametrize( - "response", - ( - {}, - {"result": None}, - ), -) -def test_time_based_cache_middleware_does_not_cache_bad_responses( - w3_base, response, time_cache_middleware -): - w3 = w3_base - counter = itertools.count() - - def mk_result(method, params): - next(counter) - return None - - w3.middleware_onion.add( - construct_result_generator_middleware({"fake_endpoint": mk_result}) - ) - w3.middleware_onion.add(time_cache_middleware) - - w3.manager.request_blocking("fake_endpoint", []) - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_time_based_cache_middleware_does_not_cache_error_response( - w3_base, time_cache_middleware -): - w3 = w3_base - counter = itertools.count() - - def mk_error(method, params): - return f"error-number-{next(counter)}" - - w3.middleware_onion.add( - construct_error_generator_middleware( - { - "fake_endpoint": mk_error, - } - ) - ) - w3.middleware_onion.add(time_cache_middleware) - - with pytest.raises(ValueError) as err: - w3.manager.request_blocking("fake_endpoint", []) - assert "error-number-0" in str(err) - - with pytest.raises(ValueError) as err: - w3.manager.request_blocking("fake_endpoint", []) - assert "error-number-1" in str(err) - - -def test_time_based_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): - result_a = w3.manager.request_blocking("not_whitelisted", []) - result_b = w3.manager.request_blocking("not_whitelisted", []) - - assert result_a != result_b diff --git a/tests/core/middleware/test_transaction_signing.py b/tests/core/middleware/test_transaction_signing.py index e1dc9e5a8e..e230556dc7 100644 --- a/tests/core/middleware/test_transaction_signing.py +++ b/tests/core/middleware/test_transaction_signing.py @@ -36,12 +36,9 @@ InvalidAddress, ) from web3.middleware import ( - async_construct_result_generator_middleware, - construct_result_generator_middleware, construct_sign_and_send_raw_middleware, ) from web3.middleware.signing import ( - async_construct_sign_and_send_raw_middleware, gen_normalized_accounts, ) from web3.providers import ( @@ -93,29 +90,21 @@ class DummyProvider(BaseProvider): def make_request(self, method, params): - raise NotImplementedError(f"Cannot make request for {method}:{params}") + raise NotImplementedError(f"Cannot make request for {method}: {params}") @pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { +def w3_dummy(request_mocker): + w3_base = Web3(provider=DummyProvider(), middlewares=[]) + with request_mocker( + w3_base, + mock_results={ "eth_sendRawTransaction": lambda *args: args, "net_version": lambda *_: 1, "eth_chainId": lambda *_: "0x02", - } - ) - - -@pytest.fixture -def w3_base(): - return Web3(provider=DummyProvider(), middlewares=[]) - - -@pytest.fixture -def w3_dummy(w3_base, result_generator_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - return w3_base + }, + ): + yield w3_base def hex_to_bytes(s): @@ -424,31 +413,23 @@ def test_sign_and_send_raw_middleware_with_byte_addresses( # -- async -- # -@pytest_asyncio.fixture -async def async_result_generator_middleware(): - return await async_construct_result_generator_middleware( - { - "eth_sendRawTransaction": lambda *args: args, - "net_version": lambda *_: 1, - "eth_chainId": lambda *_: "0x02", - } - ) - - class AsyncDummyProvider(AsyncBaseProvider): async def coro_request(self, method, params): raise NotImplementedError(f"Cannot make request for {method}:{params}") -@pytest.fixture -def async_w3_base(): - return AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) - - -@pytest.fixture -def async_w3_dummy(async_w3_base, async_result_generator_middleware): - async_w3_base.middleware_onion.add(async_result_generator_middleware) - return async_w3_base +@pytest_asyncio.fixture +async def async_w3_dummy(request_mocker): + w3_base = AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) + async with request_mocker( + w3_base, + mock_results={ + "eth_sendRawTransaction": lambda *args: args, + "net_version": 1, + "eth_chainId": "0x02", + }, + ): + yield w3_base @pytest.fixture @@ -482,7 +463,7 @@ async def test_async_sign_and_send_raw_middleware( key_object, ): async_w3_dummy.middleware_onion.add( - await async_construct_sign_and_send_raw_middleware(key_object) + construct_sign_and_send_raw_middleware(key_object) ) legacy_transaction = { @@ -553,9 +534,7 @@ async def test_async_signed_transaction( key_object, from_, ): - async_w3.middleware_onion.add( - await async_construct_sign_and_send_raw_middleware(key_object) - ) + async_w3.middleware_onion.add(construct_sign_and_send_raw_middleware(key_object)) # Drop any falsy addresses accounts = await async_w3.eth.accounts @@ -595,7 +574,7 @@ async def test_async_sign_and_send_raw_middleware_with_byte_addresses( to_ = to_converter(ADDRESS_2) async_w3_dummy.middleware_onion.add( - await async_construct_sign_and_send_raw_middleware(private_key) + construct_sign_and_send_raw_middleware(private_key) ) actual = await async_w3_dummy.manager.coro_request( diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 52b9397f01..dbd1d8d79d 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -23,16 +23,16 @@ AsyncGethTxPool, ) from web3.middleware import ( - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_name_to_address_middleware, - async_validation_middleware, + attrdict_middleware, + buffered_gas_estimate_middleware, + ens_name_to_address_middleware, + gas_price_strategy_middleware, + validation_middleware, ) from web3.net import ( AsyncNet, ) -from web3.providers.async_rpc import ( +from web3.providers.rpc import ( AsyncHTTPProvider, ) @@ -86,17 +86,17 @@ def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> assert ( async_w3.middleware_onion.get("gas_price_strategy") - == async_gas_price_strategy_middleware + == gas_price_strategy_middleware ) assert ( - async_w3.middleware_onion.get("name_to_address") - == async_name_to_address_middleware + async_w3.middleware_onion.get("ens_name_to_address") + == ens_name_to_address_middleware ) - assert async_w3.middleware_onion.get("attrdict") == async_attrdict_middleware - assert async_w3.middleware_onion.get("validation") == async_validation_middleware + assert async_w3.middleware_onion.get("attrdict") == attrdict_middleware + assert async_w3.middleware_onion.get("validation") == validation_middleware assert ( async_w3.middleware_onion.get("gas_estimate") - == async_buffered_gas_estimate_middleware + == buffered_gas_estimate_middleware ) diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 53308e3afa..4567d18fe9 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -26,11 +26,10 @@ GethTxPool, ) from web3.middleware import ( - abi_middleware, attrdict_middleware, buffered_gas_estimate_middleware, + ens_name_to_address_middleware, gas_price_strategy_middleware, - name_to_address_middleware, validation_middleware, ) from web3.net import ( @@ -80,19 +79,17 @@ def test_web3_with_http_provider_has_default_middlewares_and_modules() -> None: # the following length check should fail and will need to be added to once more # middlewares are added to the defaults - assert len(w3.middleware_onion.middlewares) == 6 + assert len(w3.middleware_onion.middlewares) == 5 assert ( w3.middleware_onion.get("gas_price_strategy") == gas_price_strategy_middleware ) assert ( - w3.middleware_onion.get("name_to_address").__name__ - == name_to_address_middleware(w3).__name__ + w3.middleware_onion.get("ens_name_to_address") == ens_name_to_address_middleware ) assert w3.middleware_onion.get("attrdict") == attrdict_middleware assert w3.middleware_onion.get("validation") == validation_middleware assert w3.middleware_onion.get("gas_estimate") == buffered_gas_estimate_middleware - assert w3.middleware_onion.get("abi") == abi_middleware def test_user_provided_session(): diff --git a/tests/core/providers/test_http_request_retry.py b/tests/core/providers/test_http_request_retry.py new file mode 100644 index 0000000000..938eb963dc --- /dev/null +++ b/tests/core/providers/test_http_request_retry.py @@ -0,0 +1,191 @@ +import pytest +from unittest.mock import ( + patch, +) + +import aiohttp +from requests.exceptions import ( + ConnectionError, + HTTPError, + Timeout, + TooManyRedirects, +) + +from web3 import ( + AsyncHTTPProvider, + AsyncWeb3, + Web3, + WebsocketProviderV2, +) +from web3.providers import ( + HTTPProvider, + IPCProvider, +) +from web3.providers.rpc.utils import ( + ExceptionRetryConfiguration, + check_if_retry_on_failure, +) +from web3.types import ( + RPCEndpoint, +) + +TEST_RETRY_COUNT = 3 + + +@pytest.fixture +def w3(): + errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) + config = ExceptionRetryConfiguration(errors=errors, retries=TEST_RETRY_COUNT) + return Web3(HTTPProvider(exception_retry_configuration=config)) + + +def test_default_request_retry_configuration_for_http_provider(): + w3 = Web3(HTTPProvider()) + assert ( + getattr(w3.provider, "exception_retry_configuration") + == ExceptionRetryConfiguration() + ) + + +def test_check_if_retry_on_failure_false(): + methods = [ + "eth_sendTransaction", + "personal_signAndSendTransaction", + "personal_sendTransaction", + ] + + for method in methods: + assert not check_if_retry_on_failure(method) + + +def test_check_if_retry_on_failure_true(): + method = "eth_getBalance" + assert check_if_retry_on_failure(method) + + +@patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) +def test_check_send_transaction_called_once(make_post_request_mock, w3): + with pytest.raises(ConnectionError): + w3.provider.make_request( + RPCEndpoint("eth_sendTransaction"), [{"to": f"0x{'00' * 20}", "value": 1}] + ) + assert make_post_request_mock.call_count == 1 + + +@patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) +def test_valid_method_retried(make_post_request_mock, w3): + with pytest.raises(ConnectionError): + w3.provider.make_request(RPCEndpoint("eth_getBalance"), [f"0x{'00' * 20}"]) + assert make_post_request_mock.call_count == TEST_RETRY_COUNT + + +def test_exception_retry_config_is_strictly_on_http_provider(): + w3 = Web3(HTTPProvider()) + assert hasattr(w3.provider, "exception_retry_configuration") + + w3 = Web3(IPCProvider()) + assert not hasattr(w3.provider, "exception_retry_configuration") + + w3 = AsyncWeb3.persistent_websocket(WebsocketProviderV2("ws://localhost:8546")) + assert not hasattr(w3.provider, "exception_retry_configuration") + + +@patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) +def test_exception_retry_middleware_with_allow_list_kwarg(make_post_request_mock): + config = ExceptionRetryConfiguration( + errors=(ConnectionError, HTTPError, Timeout, TooManyRedirects), + retries=TEST_RETRY_COUNT, + method_allowlist=["test_userProvidedMethod"], + ) + w3 = Web3(HTTPProvider(exception_retry_configuration=config)) + + with pytest.raises(ConnectionError): + w3.provider.make_request(RPCEndpoint("test_userProvidedMethod"), []) + assert make_post_request_mock.call_count == TEST_RETRY_COUNT + + make_post_request_mock.reset_mock() + with pytest.raises(ConnectionError): + w3.provider.make_request(RPCEndpoint("eth_getBalance"), []) + assert make_post_request_mock.call_count == 1 + + +# -- async -- # + + +@pytest.fixture +def async_w3(): + errors = (aiohttp.ClientError, TimeoutError) + config = ExceptionRetryConfiguration(errors=errors, retries=TEST_RETRY_COUNT) + return AsyncWeb3(AsyncHTTPProvider(exception_retry_configuration=config)) + + +@pytest.mark.asyncio +async def test_async_default_request_retry_configuration_for_http_provider(): + async_w3 = AsyncWeb3(AsyncHTTPProvider()) + assert ( + getattr(async_w3.provider, "exception_retry_configuration") + == ExceptionRetryConfiguration() + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error", + ( + TimeoutError, + aiohttp.ClientError, # general base class for all aiohttp errors + # ClientError subclasses + aiohttp.ClientConnectionError, + aiohttp.ServerTimeoutError, + aiohttp.ClientOSError, + ), +) +async def test_async_check_retry_middleware(async_w3, error): + with patch( + "web3.providers.rpc.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = error + + with pytest.raises(error): + await async_w3.provider.make_request(RPCEndpoint("eth_getBalance"), []) + assert async_make_post_request_mock.call_count == TEST_RETRY_COUNT + + +@pytest.mark.asyncio +async def test_async_check_without_retry_config(): + w3 = AsyncWeb3(AsyncHTTPProvider(exception_retry_configuration=None)) + + with patch( + "web3.providers.rpc.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = TimeoutError + + with pytest.raises(TimeoutError): + await w3.eth.block_number + assert async_make_post_request_mock.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_exception_retry_middleware_with_allow_list_kwarg(): + config = ExceptionRetryConfiguration( + errors=(aiohttp.ClientError, TimeoutError), + retries=TEST_RETRY_COUNT, + method_allowlist=["test_userProvidedMethod"], + ) + async_w3 = AsyncWeb3(AsyncHTTPProvider(exception_retry_configuration=config)) + + with patch( + "web3.providers.rpc.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = TimeoutError + + with pytest.raises(TimeoutError): + await async_w3.provider.make_request( + RPCEndpoint("test_userProvidedMethod"), [] + ) + assert async_make_post_request_mock.call_count == TEST_RETRY_COUNT + + async_make_post_request_mock.reset_mock() + with pytest.raises(TimeoutError): + await async_w3.provider.make_request(RPCEndpoint("eth_getBalance"), []) + assert async_make_post_request_mock.call_count == 1 diff --git a/tests/core/providers/test_ipc_provider.py b/tests/core/providers/test_ipc_provider.py index 2d24e49a2c..764f8f5833 100644 --- a/tests/core/providers/test_ipc_provider.py +++ b/tests/core/providers/test_ipc_provider.py @@ -14,12 +14,12 @@ from web3.exceptions import ( ProviderConnectionError, ) -from web3.middleware import ( - construct_fixture_middleware, -) from web3.providers.ipc import ( IPCProvider, ) +from web3.types import ( + RPCEndpoint, +) @pytest.fixture @@ -90,14 +90,15 @@ def test_sync_waits_for_full_result(jsonrpc_ipc_pipe_path, serve_empty_result): provider._socket.sock.close() -def test_web3_auto_gethdev(): +def test_web3_auto_gethdev(request_mocker): assert isinstance(w3.provider, IPCProvider) - return_block_with_long_extra_data = construct_fixture_middleware( - { - "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, - } - ) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - block = w3.eth.get_block("latest") + with request_mocker( + w3, + mock_results={ + RPCEndpoint("eth_getBlockByNumber"): {"extraData": "0x" + "ff" * 33} + }, + ): + block = w3.eth.get_block("latest") + assert "extraData" not in block assert block.proofOfAuthorityData == b"\xff" * 33 diff --git a/tests/core/utilities/test_fee_utils.py b/tests/core/utilities/test_fee_utils.py index 84a1fb743d..a9bbdb1639 100644 --- a/tests/core/utilities/test_fee_utils.py +++ b/tests/core/utilities/test_fee_utils.py @@ -4,14 +4,6 @@ is_integer, ) -from web3.middleware import ( - construct_error_generator_middleware, - construct_result_generator_middleware, -) -from web3.types import ( - RPCEndpoint, -) - @pytest.mark.parametrize( "fee_history_rewards,expected_max_prio_calc", @@ -37,28 +29,20 @@ ) # Test fee_utils indirectly by mocking eth_feeHistory results # and checking against expected output -def test_fee_utils_indirectly(w3, fee_history_rewards, expected_max_prio_calc) -> None: - fail_max_prio_middleware = construct_error_generator_middleware( - {RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: ""} - ) - fee_history_result_middleware = construct_result_generator_middleware( - {RPCEndpoint("eth_feeHistory"): lambda *_: {"reward": fee_history_rewards}} - ) - - w3.middleware_onion.add(fail_max_prio_middleware, "fail_max_prio") - w3.middleware_onion.inject( - fee_history_result_middleware, "fee_history_result", layer=0 - ) - +def test_fee_utils_indirectly( + w3, fee_history_rewards, expected_max_prio_calc, request_mocker +) -> None: with pytest.warns( UserWarning, match="There was an issue with the method eth_maxPriorityFeePerGas. " "Calculating using eth_feeHistory.", ): - max_priority_fee = w3.eth.max_priority_fee + with request_mocker( + w3, + mock_errors={"eth_maxPriorityFeePerGas": {}}, + mock_results={"eth_feeHistory": {"reward": fee_history_rewards}}, + ): + max_priority_fee = w3.eth.max_priority_fee + assert is_integer(max_priority_fee) assert max_priority_fee == expected_max_prio_calc - - # clean up - w3.middleware_onion.remove("fail_max_prio") - w3.middleware_onion.remove("fee_history_result") diff --git a/tests/ens/conftest.py b/tests/ens/conftest.py index 919eb463b3..793195b4e4 100644 --- a/tests/ens/conftest.py +++ b/tests/ens/conftest.py @@ -362,8 +362,7 @@ def TEST_ADDRESS(address_conversion_func): @pytest_asyncio.fixture(scope="session") def async_w3(): - provider = AsyncEthereumTesterProvider() - _async_w3 = AsyncWeb3(provider, middlewares=provider.middlewares) + _async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) return _async_w3 diff --git a/tests/ens/test_ens.py b/tests/ens/test_ens.py index af9cf99833..d5dea4397d 100644 --- a/tests/ens/test_ens.py +++ b/tests/ens/test_ens.py @@ -15,8 +15,8 @@ AsyncWeb3, ) from web3.middleware import ( - async_validation_middleware, gas_price_strategy_middleware, + validation_middleware, ) from web3.providers.eth_tester import ( AsyncEthereumTesterProvider, @@ -157,7 +157,7 @@ def local_async_w3(): def test_async_from_web3_inherits_web3_middlewares(local_async_w3): - test_middleware = async_validation_middleware + test_middleware = validation_middleware local_async_w3.middleware_onion.add(test_middleware, "test_middleware") ns = AsyncENS.from_web3(local_async_w3) diff --git a/tests/ens/test_utils.py b/tests/ens/test_utils.py index 0dd6948b09..66d9d37c26 100644 --- a/tests/ens/test_utils.py +++ b/tests/ens/test_utils.py @@ -39,8 +39,8 @@ def test_init_web3_adds_expected_middlewares(): w3 = init_web3() - middlewares = map(str, w3.manager.middleware_onion) - assert "stalecheck_middleware" in next(middlewares) + middlewares = w3.middleware_onion.middlewares + assert middlewares[0][1] == "stalecheck" @pytest.mark.parametrize( @@ -207,8 +207,8 @@ def test_label_to_hash_normalizes_name_using_ensip15(): @pytest.mark.asyncio async def test_init_async_web3_adds_expected_async_middlewares(): async_w3 = init_async_web3() - middlewares = map(str, async_w3.manager.middleware_onion) - assert "stalecheck_middleware" in next(middlewares) + middlewares = async_w3.middleware_onion.middlewares + assert middlewares[0][1] == "stalecheck" @pytest.mark.asyncio diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 5013e7cf80..9963854437 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -18,7 +18,7 @@ from web3._utils.module_testing.go_ethereum_personal_module import ( GoEthereumAsyncPersonalModuleTest, ) -from web3.providers.async_rpc import ( +from web3.providers.rpc import ( AsyncHTTPProvider, ) diff --git a/tests/integration/test_ethereum_tester.py b/tests/integration/test_ethereum_tester.py index cc5d8a4fbc..dfcc3fc97c 100644 --- a/tests/integration/test_ethereum_tester.py +++ b/tests/integration/test_ethereum_tester.py @@ -266,14 +266,6 @@ def func_wrapper(self, eth_tester, *args, **kwargs): class TestEthereumTesterEthModule(EthModuleTest): - test_eth_max_priority_fee_with_fee_history_calculation = not_implemented( - EthModuleTest.test_eth_max_priority_fee_with_fee_history_calculation, - MethodUnavailable, - ) - test_eth_max_priority_fee_with_fee_history_calculation_error_dict = not_implemented( - EthModuleTest.test_eth_max_priority_fee_with_fee_history_calculation_error_dict, - ValueError, - ) test_eth_sign = not_implemented(EthModuleTest.test_eth_sign, MethodUnavailable) test_eth_sign_ens_names = not_implemented( EthModuleTest.test_eth_sign_ens_names, MethodUnavailable diff --git a/web3/__init__.py b/web3/__init__.py index aa2ba141f6..4406207076 100644 --- a/web3/__init__.py +++ b/web3/__init__.py @@ -15,9 +15,6 @@ AsyncWeb3, Web3, ) -from web3.providers.async_rpc import ( # noqa: E402 - AsyncHTTPProvider, -) from web3.providers.eth_tester import ( # noqa: E402 EthereumTesterProvider, ) @@ -25,6 +22,7 @@ IPCProvider, ) from web3.providers.rpc import ( # noqa: E402 + AsyncHTTPProvider, HTTPProvider, ) from web3.providers.websocket import ( # noqa: E402 diff --git a/web3/_utils/caching.py b/web3/_utils/caching.py index e1339d6085..a10e901cb4 100644 --- a/web3/_utils/caching.py +++ b/web3/_utils/caching.py @@ -1,11 +1,15 @@ import collections import hashlib +import threading from typing import ( TYPE_CHECKING, Any, Callable, + Coroutine, List, Tuple, + TypeVar, + Union, ) from eth_utils import ( @@ -20,11 +24,22 @@ ) if TYPE_CHECKING: - from web3.types import ( + from web3.providers import ( # noqa: F401 + AsyncBaseProvider, + BaseProvider, + ) + from web3.types import ( # noqa: F401 + AsyncMakeRequestFn, + MakeRequestFn, RPCEndpoint, + RPCResponse, ) +SYNC_PROVIDER_TYPE = TypeVar("SYNC_PROVIDER_TYPE", bound="BaseProvider") +ASYNC_PROVIDER_TYPE = TypeVar("ASYNC_PROVIDER_TYPE", bound="AsyncBaseProvider") + + def generate_cache_key(value: Any) -> str: """ Generates a cache key for the *args and **kwargs @@ -58,3 +73,83 @@ def __init__( self.response_formatters = response_formatters self.subscription_id = subscription_id self.middleware_response_processors: List[Callable[..., Any]] = [] + + +def is_cacheable_request( + provider: Union[ASYNC_PROVIDER_TYPE, SYNC_PROVIDER_TYPE], method: "RPCEndpoint" +) -> bool: + if provider.cache_allowed_requests and method in provider.cacheable_requests: + return True + return False + + +# -- request caching decorators -- # + + +def _should_cache_response(response: "RPCResponse") -> bool: + return ( + "error" not in response + and "result" in response + and not is_null(response["result"]) + ) + + +def handle_request_caching( + func: Callable[[SYNC_PROVIDER_TYPE, "RPCEndpoint", Any], "RPCResponse"] +) -> Callable[..., "RPCResponse"]: + def wrapper( + provider: SYNC_PROVIDER_TYPE, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": + if is_cacheable_request(provider, method): + request_cache = provider._request_cache + cache_key = generate_cache_key( + f"{threading.get_ident()}:{(method, params)}" + ) + cache_result = request_cache.get_cache_entry(cache_key) + if cache_result is not None: + return cache_result + else: + response = func(provider, method, params) + if _should_cache_response(response): + with provider._request_cache_lock: + request_cache.cache(cache_key, response) + return response + else: + return func(provider, method, params) + + # save a reference to the decorator on the wrapped function + wrapper._decorator = handle_request_caching # type: ignore + return wrapper + + +# -- async -- # + + +def async_handle_request_caching( + func: Callable[ + [ASYNC_PROVIDER_TYPE, "RPCEndpoint", Any], Coroutine[Any, Any, "RPCResponse"] + ], +) -> Callable[..., Coroutine[Any, Any, "RPCResponse"]]: + async def wrapper( + provider: ASYNC_PROVIDER_TYPE, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": + if is_cacheable_request(provider, method): + request_cache = provider._request_cache + cache_key = generate_cache_key( + f"{threading.get_ident()}:{(method, params)}" + ) + cache_result = request_cache.get_cache_entry(cache_key) + if cache_result is not None: + return cache_result + else: + response = await func(provider, method, params) + if _should_cache_response(response): + async with provider._request_cache_lock: + request_cache.cache(cache_key, response) + return response + else: + return await func(provider, method, params) + + # save a reference to the decorator on the wrapped function + wrapper._decorator = async_handle_request_caching # type: ignore + return wrapper diff --git a/web3/_utils/ens.py b/web3/_utils/ens.py index a1bd8a1437..5db42531de 100644 --- a/web3/_utils/ens.py +++ b/web3/_utils/ens.py @@ -17,6 +17,7 @@ is_0x_prefixed, is_hex, is_hex_address, + to_checksum_address, ) from ens import ( @@ -51,7 +52,7 @@ def is_ens_name(value: Any) -> bool: def validate_name_has_address(ens: ENS, name: str) -> ChecksumAddress: addr = ens.address(name) if addr: - return addr + return to_checksum_address(addr) else: raise NameNotFound(f"Could not find address for name {name!r}") diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index 5f479e060f..8fb8f89b31 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -10,6 +10,7 @@ Any, Callable, List, + Type, Union, cast, ) @@ -49,6 +50,9 @@ from web3._utils.ens import ( ens_addresses, ) +from web3._utils.fee_utils import ( + PRIORITY_FEE_MIN, +) from web3._utils.method_formatters import ( to_hex_if_integer, ) @@ -58,6 +62,9 @@ mine_pending_block, mock_offchain_lookup_request_response, ) +from web3._utils.module_testing.utils import ( + RequestMocker, +) from web3._utils.type_conversion import ( to_hex_if_bytes, ) @@ -78,12 +85,7 @@ Web3ValidationError, ) from web3.middleware import ( - async_geth_poa_middleware, -) -from web3.middleware.fixture import ( - async_construct_error_generator_middleware, - async_construct_result_generator_middleware, - construct_error_generator_middleware, + extradata_to_poa_middleware, ) from web3.types import ( ENS, @@ -648,27 +650,23 @@ async def test_validation_middleware_chain_id_mismatch( await async_w3.eth.send_transaction(txn_params) @pytest.mark.asyncio - async def test_geth_poa_middleware(self, async_w3: "AsyncWeb3") -> None: - return_block_with_long_extra_data = ( - await async_construct_result_generator_middleware( - { - RPCEndpoint("eth_getBlockByNumber"): lambda *_: { - "extraData": "0x" + "ff" * 33 - }, - } - ) - ) - async_w3.middleware_onion.inject(async_geth_poa_middleware, "poa", layer=0) - async_w3.middleware_onion.inject( - return_block_with_long_extra_data, "extradata", layer=0 - ) - block = await async_w3.eth.get_block("latest") + async def test_extradata_to_poa_middleware( + self, async_w3: "AsyncWeb3", request_mocker: Type[RequestMocker] + ) -> None: + async_w3.middleware_onion.inject(extradata_to_poa_middleware, "poa", layer=0) + extra_data = f"0x{'ff' * 33}" + + async with request_mocker( + async_w3, + mock_results={"eth_getBlockByNumber": {"extraData": extra_data}}, + ): + block = await async_w3.eth.get_block("latest") + assert "extraData" not in block - assert block["proofOfAuthorityData"] == b"\xff" * 33 + assert block["proofOfAuthorityData"] == to_bytes(hexstr=extra_data) # clean up async_w3.middleware_onion.remove("poa") - async_w3.middleware_onion.remove("extradata") @pytest.mark.asyncio async def test_eth_send_raw_transaction(self, async_w3: "AsyncWeb3") -> None: @@ -852,60 +850,25 @@ async def test_eth_max_priority_fee(self, async_w3: "AsyncWeb3") -> None: max_priority_fee = await async_w3.eth.max_priority_fee assert is_integer(max_priority_fee) - @pytest.mark.asyncio - async def test_eth_max_priority_fee_with_fee_history_calculation_error_dict( - self, async_w3: "AsyncWeb3" - ) -> None: - fail_max_prio_middleware = await async_construct_error_generator_middleware( - { - RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: { - "error": { - "code": -32601, - "message": ( - "The method eth_maxPriorityFeePerGas does " - "not exist/is not available" - ), - } - } - } - ) - async_w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas." - " Calculating using eth_feeHistory." - ), - ): - await async_w3.eth.max_priority_fee - - async_w3.middleware_onion.remove("fail_max_prio_middleware") # clean up - @pytest.mark.asyncio async def test_eth_max_priority_fee_with_fee_history_calculation( - self, async_w3: "AsyncWeb3" + self, async_w3: "AsyncWeb3", request_mocker: Type[RequestMocker] ) -> None: - fail_max_prio_middleware = await async_construct_error_generator_middleware( - {RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: ""} - ) - async_w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas. " - "Calculating using eth_feeHistory." - ), + async with request_mocker( + async_w3, + mock_errors={RPCEndpoint("eth_maxPriorityFeePerGas"): {}}, + mock_results={RPCEndpoint("eth_feeHistory"): {"reward": [[0]]}}, ): - max_priority_fee = await async_w3.eth.max_priority_fee - assert is_integer(max_priority_fee) - - async_w3.middleware_onion.remove("fail_max_prio_middleware") # clean up + with pytest.warns( + UserWarning, + match=( + "There was an issue with the method eth_maxPriorityFeePerGas. " + "Calculating using eth_feeHistory." + ), + ): + priority_fee = await async_w3.eth.max_priority_fee + assert is_integer(priority_fee) + assert priority_fee == PRIORITY_FEE_MIN @pytest.mark.asyncio async def test_eth_getBlockByHash( @@ -2480,58 +2443,24 @@ def test_eth_max_priority_fee(self, w3: "Web3") -> None: max_priority_fee = w3.eth.max_priority_fee assert is_integer(max_priority_fee) - def test_eth_max_priority_fee_with_fee_history_calculation_error_dict( - self, w3: "Web3" - ) -> None: - fail_max_prio_middleware = construct_error_generator_middleware( - { - RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: { - "error": { - "code": -32601, - "message": ( - "The method eth_maxPriorityFeePerGas does " - "not exist/is not available" - ), - } - } - } - ) - w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas." - " Calculating using eth_feeHistory." - ), - ): - w3.eth.max_priority_fee - - w3.middleware_onion.remove("fail_max_prio_middleware") # clean up - def test_eth_max_priority_fee_with_fee_history_calculation( - self, w3: "Web3" + self, w3: "Web3", request_mocker: Type[RequestMocker] ) -> None: - fail_max_prio_middleware = construct_error_generator_middleware( - {RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: ""} - ) - w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas." - " Calculating using eth_feeHistory." - ), + with request_mocker( + w3, + mock_errors={RPCEndpoint("eth_maxPriorityFeePerGas"): {}}, + mock_results={RPCEndpoint("eth_feeHistory"): {"reward": [[0]]}}, ): - max_priority_fee = w3.eth.max_priority_fee - assert is_integer(max_priority_fee) - - w3.middleware_onion.remove("fail_max_prio_middleware") # clean up + with pytest.warns( + UserWarning, + match=( + "There was an issue with the method eth_maxPriorityFeePerGas. " + "Calculating using eth_feeHistory." + ), + ): + max_priority_fee = w3.eth.max_priority_fee + assert is_integer(max_priority_fee) + assert max_priority_fee == PRIORITY_FEE_MIN def test_eth_accounts(self, w3: "Web3") -> None: accounts = w3.eth.accounts diff --git a/web3/_utils/module_testing/persistent_connection_provider.py b/web3/_utils/module_testing/persistent_connection_provider.py index b07416c029..4c6ca65e86 100644 --- a/web3/_utils/module_testing/persistent_connection_provider.py +++ b/web3/_utils/module_testing/persistent_connection_provider.py @@ -21,7 +21,7 @@ AttributeDict, ) from web3.middleware import ( - async_geth_poa_middleware, + extradata_to_poa_middleware, ) from web3.types import ( FormattedEthSubscriptionResponse, @@ -34,7 +34,7 @@ def _mocked_recv(sub_id: str, ws_subscription_response: Dict[str, Any]) -> bytes: - # Must be same subscription id so we can know how to parse the message. + # Must be same subscription id, so we can know how to parse the message. # We don't have this information when mocking the response. ws_subscription_response["params"]["subscription"] = sub_id return to_bytes(text=json.dumps(ws_subscription_response)) @@ -316,12 +316,12 @@ async def _mocked_recv_coro() -> bytes: async_w3.provider._ws.__setattr__("recv", actual_recv_fxn) @pytest.mark.asyncio - async def test_async_geth_poa_middleware_on_eth_subscription( + async def test_async_extradata_to_poa_middleware_on_eth_subscription( self, async_w3: "_PersistentConnectionWeb3", ) -> None: async_w3.middleware_onion.inject( - async_geth_poa_middleware, "poa_middleware", layer=0 + extradata_to_poa_middleware, "poa_middleware", layer=0 ) sub_id = await async_w3.eth.subscribe("newHeads") diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py new file mode 100644 index 0000000000..b2e30ea274 --- /dev/null +++ b/web3/_utils/module_testing/utils.py @@ -0,0 +1,194 @@ +from asyncio import ( + iscoroutinefunction, +) +import copy +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Union, + cast, +) + +from toolz import ( + merge, +) + +if TYPE_CHECKING: + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3._utils.compat import ( # noqa: F401 + Self, + ) + from web3.types import ( # noqa: F401 + AsyncMakeRequestFn, + MakeRequestFn, + RPCEndpoint, + RPCResponse, + ) + + +class RequestMocker: + """ + Context manager to mock requests made by a web3 instance. This is meant to be used + via a ``request_mocker`` fixture defined within the appropriate context. + + Example: + + def test_my_w3(w3, request_mocker): + assert w3.eth.block_number == 0 + + with request_mocker(w3, mock_results={"eth_blockNumber": "0x1"}): + assert w3.eth.block_number == 1 + + assert w3.eth.block_number == 0 + + ``mock_results`` is a dict mapping method names to the desired "result" object of + the RPC response. ``mock_errors`` is a dict mapping method names to the desired + "error" object of the RPC response. If a method name is not in either dict, + the request is made as usual. + """ + + def __init__( + self, + w3: Union["AsyncWeb3", "Web3"], + mock_results: Dict[Union["RPCEndpoint", str], Any] = None, + mock_errors: Dict[Union["RPCEndpoint", str], Any] = None, + ): + self.w3 = w3 + self.mock_results = mock_results or {} + self.mock_errors = mock_errors or {} + self._make_request: Union[ + "AsyncMakeRequestFn", "MakeRequestFn" + ] = w3.provider.make_request + + def __enter__(self) -> "Self": + setattr(self.w3.provider, "make_request", self._mock_request_handler) + # reset request func cache to re-build request_func with mocked make_request + self.w3.provider._request_func_cache = (None, None) + + return self + + # define __exit__ with typing information + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + setattr(self.w3.provider, "make_request", self._make_request) + # reset request func cache to re-build request_func with original make_request + self.w3.provider._request_func_cache = (None, None) + + def _mock_request_handler( + self, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": + self.w3 = cast("Web3", self.w3) + self._make_request = cast("MakeRequestFn", self._make_request) + + if method not in self.mock_errors and method not in self.mock_results: + return self._make_request(method, params) + + request_id = ( + next(copy.deepcopy(self.w3.provider.request_counter)) + if hasattr(self.w3.provider, "request_counter") + else 1 + ) + response_dict = {"jsonrpc": "2.0", "id": request_id} + + if method in self.mock_results: + mock_return = self.mock_results[method] + if callable(mock_return): + mock_return = mock_return(method, params) + mocked_response = merge(response_dict, {"result": mock_return}) + elif method in self.mock_errors: + error = self.mock_errors[method] + if callable(error): + error = error(method, params) + code = error.get("code", -32000) + message = error.get("message", "Mocked error") + mocked_response = merge( + response_dict, + {"error": merge({"code": code, "message": message}, error)}, + ) + else: + raise Exception("Invariant: unreachable code path") + + decorator = getattr(self._make_request, "_decorator", None) + if decorator is not None: + # If the original make_request was decorated, we need to re-apply + # the decorator to the mocked make_request. This is necessary for + # the request caching decorator to work properly. + return decorator(lambda *_: mocked_response)( + self.w3.provider, method, params + ) + else: + return mocked_response + + # -- async -- # + async def __aenter__(self) -> "Self": + setattr(self.w3.provider, "make_request", self._async_mock_request_handler) + # reset request func cache to re-build request_func with mocked make_request + self.w3.provider._request_func_cache = (None, None) + return self + + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + setattr(self.w3.provider, "make_request", self._make_request) + # reset request func cache to re-build request_func with original make_request + self.w3.provider._request_func_cache = (None, None) + + async def _async_mock_request_handler( + self, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": + self.w3 = cast("AsyncWeb3", self.w3) + self._make_request = cast("AsyncMakeRequestFn", self._make_request) + + if method not in self.mock_errors and method not in self.mock_results: + return await self._make_request(method, params) + + request_id = ( + next(copy.deepcopy(self.w3.provider.request_counter)) + if hasattr(self.w3.provider, "request_counter") + else 1 + ) + response_dict = {"jsonrpc": "2.0", "id": request_id} + + if method in self.mock_results: + mock_return = self.mock_results[method] + if callable(mock_return): + # handle callable to make things easier since we're mocking + mock_return = mock_return(method, params) + elif iscoroutinefunction(mock_return): + # this is the "correct" way to mock the async make_request + mock_return = await mock_return(method, params) + + mocked_result = merge(response_dict, {"result": mock_return}) + + elif method in self.mock_errors: + error = self.mock_errors[method] + if callable(error): + error = error(method, params) + elif iscoroutinefunction(error): + error = await error(method, params) + + code = error.get("code", -32000) + message = error.get("message", "Mocked error") + mocked_result = merge( + response_dict, + {"error": merge({"code": code, "message": message}, error)}, + ) + + else: + raise Exception("Invariant: unreachable code path") + + decorator = getattr(self._make_request, "_decorator", None) + if decorator is not None: + # If the original make_request was decorated, we need to re-apply + # the decorator to the mocked make_request. This is necessary for + # the request caching decorator to work properly. + + async def _coro( + _provider: Any, _method: "RPCEndpoint", _params: Any + ) -> "RPCResponse": + return mocked_result + + return await decorator(_coro)(self.w3.provider, method, params) + else: + return mocked_result diff --git a/web3/auto/gethdev.py b/web3/auto/gethdev.py index ebba276991..93134cb77c 100644 --- a/web3/auto/gethdev.py +++ b/web3/auto/gethdev.py @@ -3,11 +3,11 @@ Web3, ) from web3.middleware import ( - geth_poa_middleware, + extradata_to_poa_middleware, ) from web3.providers.ipc import ( get_dev_ipc_path, ) w3 = Web3(IPCProvider(get_dev_ipc_path())) -w3.middleware_onion.inject(geth_poa_middleware, layer=0) +w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) diff --git a/web3/datastructures.py b/web3/datastructures.py index 0df838e0eb..0b33f95ae3 100644 --- a/web3/datastructures.py +++ b/web3/datastructures.py @@ -172,6 +172,12 @@ def add(self, element: TValue, name: Optional[TKey] = None) -> None: if name is None: name = cast(TKey, element) + try: + # handle unhashable types + name.__hash__() + except TypeError: + name = cast(TKey, repr(name)) + if name in self._queue: if name is element: raise ValueError("You can't add the same un-named instance twice") @@ -206,6 +212,13 @@ def inject( if layer == 0: if name is None: name = cast(TKey, element) + + try: + # handle unhashable types + name.__hash__() + except TypeError: + name = cast(TKey, repr(name)) + self._queue.move_to_end(name, last=False) elif layer == len(self._queue): return @@ -255,13 +268,6 @@ def _replace_with_new_name(self, old: TKey, new: TKey) -> None: self._queue.move_to_end(key) del self._queue[old] - def __iter__(self) -> Iterator[TKey]: - elements = self._queue.values() - if not isinstance(elements, Sequence): - # type ignored b/c elements is set as _OrderedDictValuesView[Any] on 210 - elements = list(elements) # type: ignore - return iter(reversed(elements)) - def __add__(self, other: Any) -> "NamedElementOnion[TKey, TValue]": if not isinstance(other, NamedElementOnion): # you can only combine with another ``NamedElementOnion`` @@ -284,3 +290,27 @@ def __reversed__(self) -> Iterator[TValue]: if not isinstance(elements, Sequence): elements = list(elements) return iter(elements) + + # --- iter and tupleize methods --- # + + def _reversed_middlewares(self) -> Iterator[TValue]: + elements = self._queue.values() + if not isinstance(elements, Sequence): + # type ignored b/c elements is set as _OrderedDictValuesView[Any] on 210 + elements = list(elements) # type: ignore + return reversed(elements) + + def as_tuple_of_middlewares(self) -> Tuple[TValue, ...]: + """ + This helps with type hinting since we return `Iterator[TKey]` type, though it's + actually a `Iterator[TValue]` type, for the `__iter__()` method. This is in + order to satisfy the `Mapping` interface. + """ + return tuple(self._reversed_middlewares()) + + def __iter__(self) -> Iterator[TKey]: + # ``__iter__()`` for a ``Mapping`` returns ``Iterator[TKey]`` but this + # implementation returns ``Iterator[TValue]`` on reversed values (not keys). + # This leads to typing issues, so it's better to use + # ``as_tuple_of_middlewares()`` to achieve the same result. + return iter(self._reversed_middlewares()) # type: ignore diff --git a/web3/main.py b/web3/main.py index 6c9ab7645b..604ea4fcdf 100644 --- a/web3/main.py +++ b/web3/main.py @@ -102,6 +102,7 @@ from web3.manager import ( RequestManager as DefaultRequestManager, ) +from web3.middleware.base import MiddlewareOnion from web3.module import ( Module, ) @@ -119,13 +120,11 @@ from web3.providers.ipc import ( IPCProvider, ) -from web3.providers.async_rpc import ( - AsyncHTTPProvider, -) from web3.providers.persistent import ( PersistentConnectionProvider, ) from web3.providers.rpc import ( + AsyncHTTPProvider, HTTPProvider, ) from web3.providers.websocket import ( @@ -141,8 +140,6 @@ Tracing, ) from web3.types import ( - AsyncMiddlewareOnion, - MiddlewareOnion, Wei, ) @@ -196,12 +193,17 @@ class BaseWeb3: # Managers RequestManager = DefaultRequestManager + manager: DefaultRequestManager # mypy types eth: Union[Eth, AsyncEth] net: Union[Net, AsyncNet] geth: Union[Geth, AsyncGeth] + @property + def middleware_onion(self) -> MiddlewareOnion: + return cast(MiddlewareOnion, self.manager.middleware_onion) + # Encoding and Decoding @staticmethod @wraps(to_bytes) @@ -403,10 +405,6 @@ def __init__( def is_connected(self, show_traceback: bool = False) -> bool: return self.provider.is_connected(show_traceback) - @property - def middleware_onion(self) -> MiddlewareOnion: - return cast(MiddlewareOnion, self.manager.middleware_onion) - @property def provider(self) -> BaseProvider: return cast(BaseProvider, self.manager.provider) @@ -471,10 +469,6 @@ def __init__( async def is_connected(self, show_traceback: bool = False) -> bool: return await self.provider.is_connected(show_traceback) - @property - def middleware_onion(self) -> AsyncMiddlewareOnion: - return cast(AsyncMiddlewareOnion, self.manager.middleware_onion) - @property def provider(self) -> AsyncBaseProvider: return cast(AsyncBaseProvider, self.manager.provider) diff --git a/web3/manager.py b/web3/manager.py index c0adfcca99..49b99955f4 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -37,18 +37,16 @@ MethodUnavailable, ) from web3.middleware import ( - abi_middleware, - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_name_to_address_middleware, - async_validation_middleware, attrdict_middleware, buffered_gas_estimate_middleware, + ens_name_to_address_middleware, gas_price_strategy_middleware, - name_to_address_middleware, validation_middleware, ) +from web3.middleware.base import ( + Middleware, + MiddlewareOnion, +) from web3.module import ( apply_result_formatters, ) @@ -57,10 +55,6 @@ PersistentConnectionProvider, ) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareOnion, - Middleware, - MiddlewareOnion, RPCEndpoint, RPCResponse, ) @@ -70,6 +64,9 @@ AsyncWeb3, Web3, ) + from web3.middleware.base import ( # noqa: F401 + Web3Middleware, + ) from web3.providers import ( # noqa: F401 AsyncBaseProvider, BaseProvider, @@ -119,21 +116,15 @@ def apply_null_result_formatters( class RequestManager: - logger = logging.getLogger("web3.RequestManager") + logger = logging.getLogger("web3.manager.RequestManager") - middleware_onion: Union[ - MiddlewareOnion, AsyncMiddlewareOnion, NamedElementOnion[None, None] - ] + middleware_onion: Union["MiddlewareOnion", NamedElementOnion[None, None]] def __init__( self, w3: Union["AsyncWeb3", "Web3"], provider: Optional[Union["BaseProvider", "AsyncBaseProvider"]] = None, - middlewares: Optional[ - Union[ - Sequence[Tuple[Middleware, str]], Sequence[Tuple[AsyncMiddleware, str]] - ] - ] = None, + middlewares: Optional[Sequence[Tuple[Middleware, str]]] = None, ) -> None: self.w3 = w3 @@ -143,11 +134,7 @@ def __init__( self.provider = provider if middlewares is None: - middlewares = ( - self.async_default_middlewares() - if self.provider.is_async - else self.default_middlewares(cast("Web3", w3)) - ) + middlewares = self.get_default_middlewares() self.middleware_onion = NamedElementOnion(middlewares) @@ -169,35 +156,19 @@ def provider(self, provider: Union["BaseProvider", "AsyncBaseProvider"]) -> None self._provider = provider @staticmethod - def default_middlewares(w3: "Web3") -> List[Tuple[Middleware, str]]: + def get_default_middlewares() -> List[Tuple[Middleware, str]]: """ List the default middlewares for the request manager. - Leaving w3 unspecified will prevent the middleware from resolving names. Documentation should remain in sync with these defaults. """ return [ (gas_price_strategy_middleware, "gas_price_strategy"), - (name_to_address_middleware(w3), "name_to_address"), + (ens_name_to_address_middleware, "ens_name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), - (abi_middleware, "abi"), (buffered_gas_estimate_middleware, "gas_estimate"), ] - @staticmethod - def async_default_middlewares() -> List[Tuple[AsyncMiddleware, str]]: - """ - List the default async middlewares for the request manager. - Documentation should remain in sync with these defaults. - """ - return [ - (async_gas_price_strategy_middleware, "gas_price_strategy"), - (async_name_to_address_middleware, "name_to_address"), - (async_attrdict_middleware, "attrdict"), - (async_validation_middleware, "validation"), - (async_buffered_gas_estimate_middleware, "gas_estimate"), - ] - # # Provider requests and response # @@ -206,7 +177,7 @@ def _make_request( ) -> RPCResponse: provider = cast("BaseProvider", self.provider) request_func = provider.request_func( - cast("Web3", self.w3), cast(MiddlewareOnion, self.middleware_onion) + cast("Web3", self.w3), cast("MiddlewareOnion", self.middleware_onion) ) self.logger.debug(f"Making request. Method: {method}") return request_func(method, params) @@ -216,8 +187,7 @@ async def _coro_make_request( ) -> RPCResponse: provider = cast("AsyncBaseProvider", self.provider) request_func = await provider.request_func( - cast("AsyncWeb3", self.w3), - cast(AsyncMiddlewareOnion, self.middleware_onion), + cast("AsyncWeb3", self.w3), cast("MiddlewareOnion", self.middleware_onion) ) self.logger.debug(f"Making request. Method: {method}") return await request_func(method, params) @@ -345,8 +315,7 @@ async def coro_request( async def ws_send(self, method: RPCEndpoint, params: Any) -> RPCResponse: provider = cast(PersistentConnectionProvider, self._provider) request_func = await provider.request_func( - cast("AsyncWeb3", self.w3), - cast(AsyncMiddlewareOnion, self.middleware_onion), + cast("AsyncWeb3", self.w3), cast("MiddlewareOnion", self.middleware_onion) ) self.logger.debug( "Making request to open websocket connection - " diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 39697c2720..671aad4753 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -1,74 +1,37 @@ -import functools from typing import ( - Coroutine, TYPE_CHECKING, Any, Callable, + Coroutine, Sequence, ) -from web3.types import ( - AsyncMiddleware, - Middleware, - RPCEndpoint, - RPCResponse, -) - -from .abi import ( - abi_middleware, -) -from .async_cache import ( - _async_simple_cache_middleware as async_simple_cache_middleware, - async_construct_simple_cache_middleware, -) from .attrdict import ( - async_attrdict_middleware, attrdict_middleware, ) +from .base import ( + Middleware, +) from .buffered_gas_estimate import ( - async_buffered_gas_estimate_middleware, buffered_gas_estimate_middleware, ) -from .cache import ( - _latest_block_based_cache_middleware as latest_block_based_cache_middleware, - _simple_cache_middleware as simple_cache_middleware, - _time_based_cache_middleware as time_based_cache_middleware, - construct_latest_block_based_cache_middleware, - construct_simple_cache_middleware, - construct_time_based_cache_middleware, +from .gas_price_strategy import ( + gas_price_strategy_middleware, ) -from .exception_handling import ( - construct_exception_handler_middleware, +from .proof_of_authority import ( + extradata_to_poa_middleware, ) -from .exception_retry_request import ( - async_http_retry_request_middleware, - http_retry_request_middleware, +from .names import ( + ens_name_to_address_middleware, ) from .filter import ( - async_local_filter_middleware, local_filter_middleware, ) -from .fixture import ( - async_construct_error_generator_middleware, - async_construct_result_generator_middleware, - construct_error_generator_middleware, - construct_fixture_middleware, - construct_result_generator_middleware, -) from .formatting import ( construct_formatting_middleware, ) from .gas_price_strategy import ( - async_gas_price_strategy_middleware, - gas_price_strategy_middleware, -) -from .geth_poa import ( - async_geth_poa_middleware, - geth_poa_middleware, -) -from .names import ( - async_name_to_address_middleware, - name_to_address_middleware, + GasPriceStrategyMiddleware, ) from .normalize_request_parameters import ( request_parameter_normalizer, @@ -78,56 +41,61 @@ ) from .signing import ( construct_sign_and_send_raw_middleware, + SignAndSendRawMiddlewareBuilder, ) from .stalecheck import ( - async_make_stalecheck_middleware, + StaleCheckMiddlewareBuilder, make_stalecheck_middleware, ) from .validation import ( - async_validation_middleware, validation_middleware, ) +from ..types import ( + AsyncMakeRequestFn, + MakeRequestFn, +) + if TYPE_CHECKING: - from web3 import AsyncWeb3, Web3 + from web3 import ( + AsyncWeb3, + Web3, + ) + from web3.types import ( + RPCResponse, + ) def combine_middlewares( middlewares: Sequence[Middleware], w3: "Web3", - provider_request_fn: Callable[[RPCEndpoint, Any], Any], -) -> Callable[..., RPCResponse]: + provider_request_fn: MakeRequestFn, +) -> Callable[..., "RPCResponse"]: """ - Returns a callable function which will call the provider.provider_request - function wrapped with all of the middlewares. + Returns a callable function which takes method and params as positional arguments + and passes these args through the request processors, makes the request, and passes + the response through the response processors. """ - return functools.reduce( - lambda request_fn, middleware: middleware(request_fn, w3), - reversed(middlewares), - provider_request_fn, - ) + accumulator_fn = provider_request_fn + for middleware in reversed(middlewares): + # initialize the middleware and wrap the accumulator function down the stack + accumulator_fn = middleware(w3)._wrap_make_request(accumulator_fn) + return accumulator_fn async def async_combine_middlewares( - middlewares: Sequence[AsyncMiddleware], + middlewares: Sequence[Middleware], async_w3: "AsyncWeb3", - provider_request_fn: Callable[[RPCEndpoint, Any], Any], -) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: + provider_request_fn: AsyncMakeRequestFn, +) -> Callable[..., Coroutine[Any, Any, "RPCResponse"]]: """ - Returns a callable function which will call the provider.provider_request - function wrapped with all of the middlewares. + Returns a callable function which takes method and params as positional arguments + and passes these args through the request processors, makes the request, and passes + the response through the response processors. """ accumulator_fn = provider_request_fn for middleware in reversed(middlewares): - accumulator_fn = await construct_middleware( - middleware, accumulator_fn, async_w3 - ) + # initialize the middleware and wrap the accumulator function down the stack + initialized = middleware(async_w3) + accumulator_fn = await initialized._async_wrap_make_request(accumulator_fn) return accumulator_fn - - -async def construct_middleware( - async_middleware: AsyncMiddleware, - fn: Callable[..., RPCResponse], - async_w3: "AsyncWeb3", -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - return await async_middleware(fn, async_w3) diff --git a/web3/middleware/abi.py b/web3/middleware/abi.py deleted file mode 100644 index 97e2063df1..0000000000 --- a/web3/middleware/abi.py +++ /dev/null @@ -1,11 +0,0 @@ -from web3._utils.method_formatters import ( - ABI_REQUEST_FORMATTERS, -) - -from .formatting import ( - construct_formatting_middleware, -) - -abi_middleware = construct_formatting_middleware( - request_formatters=ABI_REQUEST_FORMATTERS -) diff --git a/web3/middleware/async_cache.py b/web3/middleware/async_cache.py deleted file mode 100644 index ab354cdd58..0000000000 --- a/web3/middleware/async_cache.py +++ /dev/null @@ -1,99 +0,0 @@ -from concurrent.futures import ( - ThreadPoolExecutor, -) -import threading -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, -) - -from web3._utils.async_caching import ( - async_lock, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3.middleware.cache import ( - SIMPLE_CACHE_RPC_WHITELIST, - _should_cache_response, -) -from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, - Middleware, - RPCEndpoint, - RPCResponse, -) -from web3.utils.caching import ( - SimpleCache, -) - -if TYPE_CHECKING: - from web3 import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - -_async_request_thread_pool = ThreadPoolExecutor() - - -async def async_construct_simple_cache_middleware( - cache: SimpleCache = None, - rpc_whitelist: Collection[RPCEndpoint] = SIMPLE_CACHE_RPC_WHITELIST, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> AsyncMiddleware: - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` - - :param cache: A ``SimpleCache`` class. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - - async def async_simple_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - lock = threading.Lock() - - # It's not imperative that we define ``_cache`` here rather than in - # ``async_construct_simple_cache_middleware``. Due to the nature of async, - # construction is awaited and doesn't happen at import. This means separate - # instances would still get unique caches. However, to keep the code consistent - # with the synchronous version, and provide less ambiguity, we define it - # similarly to the synchronous version here. - _cache = cache if cache else SimpleCache(256) - - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in rpc_whitelist: - cache_key = generate_cache_key( - f"{threading.get_ident()}:{(method, params)}" - ) - cached_request = _cache.get_cache_entry(cache_key) - if cached_request is not None: - return cached_request - - response = await make_request(method, params) - if should_cache_fn(method, params, response): - async with async_lock(_async_request_thread_pool, lock): - _cache.cache(cache_key, response) - return response - else: - return await make_request(method, params) - - return middleware - - return async_simple_cache_middleware - - -async def _async_simple_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> Middleware: - middleware = await async_construct_simple_cache_middleware() - return await middleware(make_request, async_w3) diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index ea0c152d5d..20051a6c33 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -1,8 +1,9 @@ +from abc import ( + ABC, +) from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, cast, ) @@ -13,10 +14,8 @@ from web3.datastructures import ( AttributeDict, ) -from web3.types import ( - AsyncMiddlewareCoroutine, - RPCEndpoint, - RPCResponse, +from web3.middleware.base import ( + Web3Middleware, ) if TYPE_CHECKING: @@ -27,11 +26,31 @@ from web3.providers import ( # noqa: F401 PersistentConnectionProvider, ) + from web3.types import ( # noqa: F401 + RPCEndpoint, + RPCResponse, + ) -def attrdict_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: +def _handle_async_response(response: "RPCResponse") -> "RPCResponse": + if "result" in response: + return assoc(response, "result", AttributeDict.recursive(response["result"])) + elif "params" in response and "result" in response["params"]: + # this is a subscription response + return assoc( + response, + "params", + assoc( + response["params"], + "result", + AttributeDict.recursive(response["params"]["result"]), + ), + ) + else: + return response + + +class AttributeDictMiddleware(Web3Middleware, ABC): """ Converts any result which is a dictionary into an `AttributeDict`. @@ -39,9 +58,7 @@ def attrdict_middleware( (e.g. my_attribute_dict.property1) will not preserve typing. """ - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - response = make_request(method, params) - + def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: if "result" in response: return assoc( response, "result", AttributeDict.recursive(response["result"]) @@ -49,27 +66,14 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: else: return response - return middleware - - -# --- async --- # - + # -- async -- # -async def async_attrdict_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - """ - Converts any result which is a dictionary into an `AttributeDict`. - - Note: Accessing `AttributeDict` properties via attribute - (e.g. my_attribute_dict.property1) will not preserve typing. - """ - - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - response = await make_request(method, params) - if async_w3.provider.has_persistent_connection: + async def async_response_processor( + self, method: "RPCEndpoint", response: "RPCResponse" + ) -> Any: + if self._w3.provider.has_persistent_connection: # asynchronous response processing - provider = cast("PersistentConnectionProvider", async_w3.provider) + provider = cast("PersistentConnectionProvider", self._w3.provider) provider._request_processor.append_middleware_response_processor( response, _handle_async_response ) @@ -77,22 +81,5 @@ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: else: return _handle_async_response(response) - return middleware - -def _handle_async_response(response: RPCResponse) -> RPCResponse: - if "result" in response: - return assoc(response, "result", AttributeDict.recursive(response["result"])) - elif "params" in response and "result" in response["params"]: - # this is a subscription response - return assoc( - response, - "params", - assoc( - response["params"], - "result", - AttributeDict.recursive(response["params"]["result"]), - ), - ) - else: - return response +attrdict_middleware = AttributeDictMiddleware diff --git a/web3/middleware/base.py b/web3/middleware/base.py new file mode 100644 index 0000000000..e4c2a044cc --- /dev/null +++ b/web3/middleware/base.py @@ -0,0 +1,129 @@ +from abc import ( + abstractmethod, +) +from typing import ( + TYPE_CHECKING, + Any, + Type, + Union, +) + +from web3.datastructures import ( + NamedElementOnion, +) + +if TYPE_CHECKING: + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3.types import ( # noqa: F401 + AsyncMakeRequestFn, + MakeRequestFn, + RPCEndpoint, + RPCResponse, + ) + + +class Web3Middleware: + """ + Base class for web3.py middleware. This class is not meant to be used directly, + but instead inherited from. + """ + + _w3: Union["AsyncWeb3", "Web3"] + + def __init__(self, w3: Union["AsyncWeb3", "Web3"]) -> None: + self._w3 = w3 + + # -- sync -- # + + def _wrap_make_request(self, make_request: "MakeRequestFn") -> "MakeRequestFn": + def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": + method, params = self.request_processor(method, params) + return self.response_processor(method, make_request(method, params)) + + return middleware + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + return method, params + + def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: + return response + + # -- async -- # + + async def _async_wrap_make_request( + self, make_request: "AsyncMakeRequestFn" + ) -> "AsyncMakeRequestFn": + async def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": + method, params = await self.async_request_processor(method, params) + return await self.async_response_processor( + method, + await make_request(method, params), + ) + + return middleware + + async def async_request_processor( + self, + method: "RPCEndpoint", + params: Any, + ) -> Any: + return method, params + + async def async_response_processor( + self, + method: "RPCEndpoint", + response: "RPCResponse", + ) -> Any: + return response + + +class Web3MiddlewareBuilder(Web3Middleware): + @staticmethod + @abstractmethod + def build( + w3: Union["AsyncWeb3", "Web3"], + *args: Any, + **kwargs: Any, + ) -> Web3Middleware: + """ + Implementation should initialize the middleware class that implements it, + load it with any of the necessary properties that it needs for processing, + and curry for the ``w3`` argument since it isn't initially present when building + the middleware. + + example implementation: + + ```py + class MyMiddleware(Web3BuilderMiddleware): + internal_property: str = None + + @staticmethod + @curry + def builder(user_provided_argument, w3): + middleware = MyMiddleware(w3) + middleware.internal_property = user_provided_argument + return middleware + + def request_processor(self, method, params): + ... + + def response_processor(self, method, response): + ... + + construct_my_middleware = MyMiddleware.builder + + w3 = Web3(provider) + my_middleware = construct_my_middleware("my argument") + w3.middleware_onion.inject(my_middleware, layer=0) + ``` + """ + raise NotImplementedError("Must be implemented by subclasses") + + +# --- type definitions --- # + +Middleware = Type[Web3Middleware] +MiddlewareOnion = NamedElementOnion[str, Middleware] diff --git a/web3/middleware/buffered_gas_estimate.py b/web3/middleware/buffered_gas_estimate.py index 60f1cc739b..bb9ab509e6 100644 --- a/web3/middleware/buffered_gas_estimate.py +++ b/web3/middleware/buffered_gas_estimate.py @@ -1,7 +1,7 @@ from typing import ( TYPE_CHECKING, Any, - Callable, + cast, ) from eth_utils.toolz import ( @@ -14,10 +14,11 @@ from web3._utils.transactions import ( get_buffered_gas_estimate, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddlewareCoroutine, RPCEndpoint, - RPCResponse, ) if TYPE_CHECKING: @@ -27,34 +28,35 @@ ) -def buffered_gas_estimate_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: +class BufferedGasEstimateMiddleware(Web3Middleware): + """ + Includes a gas estimate for all transactions that do not already have a gas value. + """ + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] if "gas" not in transaction: transaction = assoc( transaction, "gas", - hex(get_buffered_gas_estimate(w3, transaction)), + hex(get_buffered_gas_estimate(cast("Web3", self._w3), transaction)), ) - return make_request(method, [transaction]) - return make_request(method, params) - - return middleware + params = (transaction,) + return method, params + # -- async -- # -async def async_buffered_gas_estimate_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] if "gas" not in transaction: - gas_estimate = await async_get_buffered_gas_estimate(w3, transaction) + gas_estimate = await async_get_buffered_gas_estimate( + cast("AsyncWeb3", self._w3), transaction + ) transaction = assoc(transaction, "gas", hex(gas_estimate)) - return await make_request(method, [transaction]) - return await make_request(method, params) + params = (transaction,) + return method, params + - return middleware +buffered_gas_estimate_middleware = BufferedGasEstimateMiddleware diff --git a/web3/middleware/cache.py b/web3/middleware/cache.py deleted file mode 100644 index 189cf706cb..0000000000 --- a/web3/middleware/cache.py +++ /dev/null @@ -1,374 +0,0 @@ -import functools -import threading -import time -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, - Dict, - Set, - cast, -) - -from eth_utils import ( - is_list_like, -) -import lru - -from web3._utils.caching import ( - generate_cache_key, -) -from web3._utils.compat import ( - Literal, - TypedDict, -) -from web3.types import ( - BlockData, - BlockNumber, - Middleware, - RPCEndpoint, - RPCResponse, -) -from web3.utils.caching import ( - SimpleCache, -) - -if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 - -SIMPLE_CACHE_RPC_WHITELIST = cast( - Set[RPCEndpoint], - ( - "web3_clientVersion", - "net_version", - "eth_getBlockTransactionCountByHash", - "eth_getUncleCountByBlockHash", - "eth_getBlockByHash", - "eth_getTransactionByHash", - "eth_getTransactionByBlockHashAndIndex", - "eth_getRawTransactionByHash", - "eth_getUncleByBlockHashAndIndex", - "eth_chainId", - ), -) - - -def _should_cache_response( - _method: RPCEndpoint, _params: Any, response: RPCResponse -) -> bool: - return ( - "error" not in response - and "result" in response - and response["result"] is not None - ) - - -def construct_simple_cache_middleware( - cache: SimpleCache = None, - rpc_whitelist: Collection[RPCEndpoint] = None, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> Middleware: - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` - - :param cache: A ``SimpleCache`` class. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - if rpc_whitelist is None: - rpc_whitelist = SIMPLE_CACHE_RPC_WHITELIST - - def simple_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], RPCResponse], _w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - lock = threading.Lock() - - # Setting the cache here, rather than in ``construct_simple_cache_middleware``, - # ensures that each instance of the middleware has its own cache. This is - # important for compatibility with multiple ``Web3`` instances. - _cache = cache if cache else SimpleCache(256) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in rpc_whitelist: - cache_key = generate_cache_key( - f"{threading.get_ident()}:{(method, params)}" - ) - cached_request = _cache.get_cache_entry(cache_key) - if cached_request is not None: - return cached_request - - response = make_request(method, params) - if should_cache_fn(method, params, response): - if lock.acquire(blocking=False): - try: - _cache.cache(cache_key, response) - finally: - lock.release() - return response - else: - return make_request(method, params) - - return middleware - - return simple_cache_middleware - - -_simple_cache_middleware = construct_simple_cache_middleware() - - -TIME_BASED_CACHE_RPC_WHITELIST = cast( - Set[RPCEndpoint], - { - "eth_coinbase", - "eth_accounts", - }, -) - - -def construct_time_based_cache_middleware( - cache_class: Callable[..., Dict[Any, Any]], - cache_expire_seconds: int = 15, - rpc_whitelist: Collection[RPCEndpoint] = TIME_BASED_CACHE_RPC_WHITELIST, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> Middleware: - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` for a maximum amount of time as specified - - :param cache_class: Any dictionary-like object - :param cache_expire_seconds: The number of seconds an item may be cached - before it should expire. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - - def time_based_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - cache = cache_class() - lock = threading.Lock() - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - lock_acquired = ( - lock.acquire(blocking=False) if method in rpc_whitelist else False - ) - - try: - if lock_acquired and method in rpc_whitelist: - cache_key = generate_cache_key((method, params)) - if cache_key in cache: - # check that the cached response is not expired. - cached_at, cached_response = cache[cache_key] - cached_for = time.time() - cached_at - - if cached_for <= cache_expire_seconds: - return cached_response - else: - del cache[cache_key] - - # cache either missed or expired so make the request. - response = make_request(method, params) - - if should_cache_fn(method, params, response): - cache[cache_key] = (time.time(), response) - - return response - else: - return make_request(method, params) - finally: - if lock_acquired: - lock.release() - - return middleware - - return time_based_cache_middleware - - -_time_based_cache_middleware = construct_time_based_cache_middleware( - cache_class=functools.partial(lru.LRU, 256), -) - - -BLOCK_NUMBER_RPC_WHITELIST = cast( - Set[RPCEndpoint], - { - "eth_gasPrice", - "eth_blockNumber", - "eth_getBalance", - "eth_getStorageAt", - "eth_getTransactionCount", - "eth_getBlockTransactionCountByNumber", - "eth_getUncleCountByBlockNumber", - "eth_getCode", - "eth_call", - "eth_createAccessList", - "eth_estimateGas", - "eth_getBlockByNumber", - "eth_getTransactionByBlockNumberAndIndex", - "eth_getTransactionReceipt", - "eth_getUncleByBlockNumberAndIndex", - "eth_getLogs", - }, -) - -AVG_BLOCK_TIME_KEY: Literal["avg_block_time"] = "avg_block_time" -AVG_BLOCK_SAMPLE_SIZE_KEY: Literal["avg_block_sample_size"] = "avg_block_sample_size" -AVG_BLOCK_TIME_UPDATED_AT_KEY: Literal[ - "avg_block_time_updated_at" -] = "avg_block_time_updated_at" - - -def _is_latest_block_number_request(method: RPCEndpoint, params: Any) -> bool: - if method != "eth_getBlockByNumber": - return False - elif is_list_like(params) and tuple(params[:1]) == ("latest",): - return True - return False - - -BlockInfoCache = TypedDict( - "BlockInfoCache", - { - "avg_block_time": float, - "avg_block_sample_size": int, - "avg_block_time_updated_at": float, - "latest_block": BlockData, - }, - total=False, -) - - -def construct_latest_block_based_cache_middleware( - cache_class: Callable[..., Dict[Any, Any]], - rpc_whitelist: Collection[RPCEndpoint] = BLOCK_NUMBER_RPC_WHITELIST, - average_block_time_sample_size: int = 240, - default_average_block_time: int = 15, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> Middleware: - """ - Constructs a middleware which caches responses based on the request - ``method``, ``params``, and the current latest block hash. - - :param cache_class: Any dictionary-like object - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param average_block_time_sample_size: number of blocks to look back when computing - average block time - :param default_average_block_time: estimated number of seconds per block - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - - .. note:: - This middleware avoids re-fetching the current latest block for each - request by tracking the current average block time and only requesting - a new block when the last seen latest block is older than the average - block time. - """ - - def latest_block_based_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - cache = cache_class() - block_info: BlockInfoCache = {} - - def _update_block_info_cache() -> None: - avg_block_time = block_info.get( - AVG_BLOCK_TIME_KEY, default_average_block_time - ) - avg_block_sample_size = block_info.get(AVG_BLOCK_SAMPLE_SIZE_KEY, 0) - avg_block_time_updated_at = block_info.get(AVG_BLOCK_TIME_UPDATED_AT_KEY, 0) - - # compute age as counted by number of blocks since the avg_block_time - if avg_block_time == 0: - avg_block_time_age_in_blocks: float = avg_block_sample_size - else: - avg_block_time_age_in_blocks = ( - time.time() - avg_block_time_updated_at - ) / avg_block_time - - if avg_block_time_age_in_blocks >= avg_block_sample_size: - # If the length of time since the average block time as - # measured by blocks is greater than or equal to the number of - # blocks sampled then we need to recompute the average block - # time. - latest_block = w3.eth.get_block("latest") - ancestor_block_number = BlockNumber( - max( - 0, - latest_block["number"] - average_block_time_sample_size, - ) - ) - ancestor_block = w3.eth.get_block(ancestor_block_number) - sample_size = latest_block["number"] - ancestor_block_number - - block_info[AVG_BLOCK_SAMPLE_SIZE_KEY] = sample_size - if sample_size != 0: - block_info[AVG_BLOCK_TIME_KEY] = ( - latest_block["timestamp"] - ancestor_block["timestamp"] - ) / sample_size - else: - block_info[AVG_BLOCK_TIME_KEY] = avg_block_time - block_info[AVG_BLOCK_TIME_UPDATED_AT_KEY] = time.time() - - if "latest_block" in block_info: - latest_block = block_info["latest_block"] - time_since_latest_block = time.time() - latest_block["timestamp"] - - # latest block is too old so update cache - if time_since_latest_block > avg_block_time: - block_info["latest_block"] = w3.eth.get_block("latest") - else: - # latest block has not been fetched so we fetch it. - block_info["latest_block"] = w3.eth.get_block("latest") - - lock = threading.Lock() - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - lock_acquired = ( - lock.acquire(blocking=False) if method in rpc_whitelist else False - ) - - try: - should_try_cache = ( - lock_acquired - and method in rpc_whitelist - and not _is_latest_block_number_request(method, params) - ) - if should_try_cache: - _update_block_info_cache() - latest_block_hash = block_info["latest_block"]["hash"] - cache_key = generate_cache_key((latest_block_hash, method, params)) - if cache_key in cache: - return cache[cache_key] - - response = make_request(method, params) - if should_cache_fn(method, params, response): - cache[cache_key] = response - return response - else: - return make_request(method, params) - finally: - if lock_acquired: - lock.release() - - return middleware - - return latest_block_based_cache_middleware - - -_latest_block_based_cache_middleware = construct_latest_block_based_cache_middleware( - cache_class=functools.partial(lru.LRU, 256), - rpc_whitelist=BLOCK_NUMBER_RPC_WHITELIST, -) diff --git a/web3/middleware/exception_handling.py b/web3/middleware/exception_handling.py deleted file mode 100644 index d264c7eaf7..0000000000 --- a/web3/middleware/exception_handling.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - Tuple, - Type, -) - -from eth_utils.toolz import ( - excepts, -) - -from web3.types import ( - Middleware, - RPCEndpoint, - RPCResponse, -) - -if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 - - -def construct_exception_handler_middleware( - method_handlers: Optional[ - Dict[RPCEndpoint, Tuple[Type[BaseException], Callable[..., None]]] - ] = None -) -> Middleware: - if method_handlers is None: - method_handlers = {} - - def exception_handler_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in method_handlers: - exc_type, handler = method_handlers[method] - return excepts( - exc_type, - make_request, - handler, - )(method, params) - else: - return make_request(method, params) - - return middleware - - return exception_handler_middleware diff --git a/web3/middleware/exception_retry_request.py b/web3/middleware/exception_retry_request.py deleted file mode 100644 index aa216c4e28..0000000000 --- a/web3/middleware/exception_retry_request.py +++ /dev/null @@ -1,188 +0,0 @@ -import asyncio -import time -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, - List, - Optional, - Type, -) - -import aiohttp -from requests.exceptions import ( - ConnectionError, - HTTPError, - Timeout, - TooManyRedirects, -) - -from web3.types import ( - AsyncMiddlewareCoroutine, - RPCEndpoint, - RPCResponse, -) - -if TYPE_CHECKING: - from web3 import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - -DEFAULT_ALLOWLIST = [ - "admin", - "miner", - "net", - "txpool", - "testing", - "evm", - "eth_protocolVersion", - "eth_syncing", - "eth_coinbase", - "eth_mining", - "eth_hashrate", - "eth_chainId", - "eth_gasPrice", - "eth_accounts", - "eth_blockNumber", - "eth_getBalance", - "eth_getStorageAt", - "eth_getProof", - "eth_getCode", - "eth_getBlockByNumber", - "eth_getBlockByHash", - "eth_getBlockTransactionCountByNumber", - "eth_getBlockTransactionCountByHash", - "eth_getUncleCountByBlockNumber", - "eth_getUncleCountByBlockHash", - "eth_getTransactionByHash", - "eth_getTransactionByBlockHashAndIndex", - "eth_getTransactionByBlockNumberAndIndex", - "eth_getTransactionReceipt", - "eth_getTransactionCount", - "eth_getRawTransactionByHash", - "eth_call", - "eth_createAccessList", - "eth_estimateGas", - "eth_maxPriorityFeePerGas", - "eth_newBlockFilter", - "eth_newPendingTransactionFilter", - "eth_newFilter", - "eth_getFilterChanges", - "eth_getFilterLogs", - "eth_getLogs", - "eth_uninstallFilter", - "eth_getCompilers", - "eth_getWork", - "eth_sign", - "eth_signTypedData", - "eth_sendRawTransaction", - "personal_importRawKey", - "personal_newAccount", - "personal_listAccounts", - "personal_listWallets", - "personal_lockAccount", - "personal_unlockAccount", - "personal_ecRecover", - "personal_sign", - "personal_signTypedData", -] - - -def check_if_retry_on_failure( - method: str, allow_list: Optional[List[str]] = None -) -> bool: - if allow_list is None: - allow_list = DEFAULT_ALLOWLIST - - root = method.split("_")[0] - if root in allow_list: - return True - elif method in allow_list: - return True - else: - return False - - -def exception_retry_middleware( - make_request: Callable[[RPCEndpoint, Any], RPCResponse], - _w3: "Web3", - errors: Collection[Type[BaseException]], - retries: int = 5, - backoff_factor: float = 0.3, - allow_list: Optional[List[str]] = None, -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - """ - Creates middleware that retries failed HTTP requests. Is a default - middleware for HTTPProvider. - """ - - def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if check_if_retry_on_failure(method, allow_list): - for i in range(retries): - try: - return make_request(method, params) - except tuple(errors): - if i < retries - 1: - time.sleep(backoff_factor) - continue - else: - raise - return None - else: - return make_request(method, params) - - return middleware - - -def http_retry_request_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], Any]: - return exception_retry_middleware( - make_request, w3, (ConnectionError, HTTPError, Timeout, TooManyRedirects) - ) - - -# -- async -- # - - -async def async_exception_retry_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - _async_w3: "AsyncWeb3", - errors: Collection[Type[BaseException]], - retries: int = 5, - backoff_factor: float = 0.3, - allow_list: Optional[List[str]] = None, -) -> AsyncMiddlewareCoroutine: - """ - Creates middleware that retries failed HTTP requests. - Is a default middleware for AsyncHTTPProvider. - """ - - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if check_if_retry_on_failure(method, allow_list): - for i in range(retries): - try: - return await make_request(method, params) - except tuple(errors): - if i < retries - 1: - await asyncio.sleep(backoff_factor) - continue - else: - raise - return None - else: - return await make_request(method, params) - - return middleware - - -async def async_http_retry_request_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> Callable[[RPCEndpoint, Any], Any]: - return await async_exception_retry_middleware( - make_request, - async_w3, - (TimeoutError, aiohttp.ClientError), - ) diff --git a/web3/middleware/filter.py b/web3/middleware/filter.py index 1350a79f8c..42afa5dc60 100644 --- a/web3/middleware/filter.py +++ b/web3/middleware/filter.py @@ -5,7 +5,6 @@ Any, AsyncIterable, AsyncIterator, - Callable, Dict, Generator, Iterable, @@ -43,18 +42,27 @@ from web3._utils.rpc_abi import ( RPC, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - Coroutine, + AsyncMakeRequestFn, FilterParams, LatestBlockParam, LogReceipt, - RPCEndpoint, - RPCResponse, + MakeRequestFn, _Hash32, ) if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3.types import ( # noqa: F401 + RPCEndpoint, + RPCResponse, + ) if "WEB3_MAX_BLOCK_REQUEST" in os.environ: MAX_BLOCK_REQUEST = to_int(text=os.environ["WEB3_MAX_BLOCK_REQUEST"]) @@ -338,57 +346,11 @@ def block_hashes_in_range( yield getattr(w3.eth.get_block(BlockNumber(block_number)), "hash", None) -def local_filter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - filters = {} - filter_id_counter = map(to_hex, itertools.count()) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in NEW_FILTER_METHODS: - filter_id = next(filter_id_counter) - - _filter: Union[RequestLogs, RequestBlocks] - if method == RPC.eth_newFilter: - _filter = RequestLogs( - w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) - ) - - elif method == RPC.eth_newBlockFilter: - _filter = RequestBlocks(w3) - - else: - raise NotImplementedError(method) - - filters[filter_id] = _filter - return {"result": filter_id} - - elif method in FILTER_CHANGES_METHODS: - filter_id = params[0] - # Pass through to filters not created by middleware - if filter_id not in filters: - return make_request(method, params) - _filter = filters[filter_id] - if method == RPC.eth_getFilterChanges: - return {"result": next(_filter.filter_changes)} - - elif method == RPC.eth_getFilterLogs: - # type ignored b/c logic prevents RequestBlocks which - # doesn't implement get_logs - return {"result": _filter.get_logs()} # type: ignore - else: - raise NotImplementedError(method) - else: - return make_request(method, params) - - return middleware - - # --- async --- # async def async_iter_latest_block( - w3: "Web3", to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None + w3: "AsyncWeb3", to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None ) -> AsyncIterable[BlockNumber]: """Returns a generator that dispenses the latest block, if any new blocks have been mined since last iteration. @@ -411,9 +373,9 @@ async def async_iter_latest_block( is_bounded_range = to_block is not None and to_block != "latest" while True: - latest_block = await w3.eth.block_number # type: ignore + latest_block = await w3.eth.block_number # type ignored b/c is_bounded_range prevents unsupported comparison - if is_bounded_range and latest_block > to_block: + if is_bounded_range and latest_block > cast(int, to_block): yield None # No new blocks since last iteration. if _last is not None and _last == latest_block: @@ -424,7 +386,7 @@ async def async_iter_latest_block( async def async_iter_latest_block_ranges( - w3: "Web3", + w3: "AsyncWeb3", from_block: BlockNumber, to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, ) -> AsyncIterable[Tuple[Optional[BlockNumber], Optional[BlockNumber]]]: @@ -454,7 +416,7 @@ async def async_iter_latest_block_ranges( async def async_get_logs_multipart( - w3: "Web3", + w3: "AsyncWeb3", start_block: BlockNumber, stop_block: BlockNumber, address: Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]], @@ -477,7 +439,7 @@ async def async_get_logs_multipart( params_with_none_dropped = cast( FilterParams, drop_items_with_none_value(params) ) - next_logs = await w3.eth.get_logs(params_with_none_dropped) # type: ignore + next_logs = await w3.eth.get_logs(params_with_none_dropped) yield next_logs @@ -486,7 +448,7 @@ class AsyncRequestLogs: def __init__( self, - w3: "Web3", + w3: "AsyncWeb3", from_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, address: Optional[ @@ -504,7 +466,7 @@ def __init__( def __await__(self) -> Generator[Any, None, "AsyncRequestLogs"]: async def closure() -> "AsyncRequestLogs": if self._from_block_arg is None or self._from_block_arg == "latest": - self.block_number = await self.w3.eth.block_number # type: ignore + self.block_number = await self.w3.eth.block_number self._from_block = BlockNumber(self.block_number + 1) elif is_string(self._from_block_arg) and is_hex(self._from_block_arg): self._from_block = BlockNumber( @@ -524,7 +486,7 @@ async def from_block(self) -> BlockNumber: @property async def to_block(self) -> BlockNumber: if self._to_block is None or self._to_block == "latest": - to_block = await self.w3.eth.block_number # type: ignore + to_block = await self.w3.eth.block_number elif is_string(self._to_block) and is_hex(self._to_block): to_block = BlockNumber(hex_to_integer(cast(HexStr, self._to_block))) else: @@ -572,12 +534,12 @@ async def get_logs(self) -> List[LogReceipt]: class AsyncRequestBlocks: - def __init__(self, w3: "Web3") -> None: + def __init__(self, w3: "AsyncWeb3") -> None: self.w3 = w3 def __await__(self) -> Generator[Any, None, "AsyncRequestBlocks"]: async def closure() -> "AsyncRequestBlocks": - self.block_number = await self.w3.eth.block_number # type: ignore + self.block_number = await self.w3.eth.block_number self.start_block = BlockNumber(self.block_number + 1) return self @@ -597,7 +559,7 @@ async def get_filter_changes(self) -> AsyncIterator[List[Hash32]]: async def async_block_hashes_in_range( - w3: "Web3", block_range: Tuple[BlockNumber, BlockNumber] + w3: "AsyncWeb3", block_range: Tuple[BlockNumber, BlockNumber] ) -> List[Union[None, Hash32]]: from_block, to_block = block_range if from_block is None or to_block is None: @@ -605,56 +567,115 @@ async def async_block_hashes_in_range( block_hashes = [] for block_number in range(from_block, to_block + 1): - w3_get_block = await w3.eth.get_block(BlockNumber(block_number)) # type: ignore + w3_get_block = await w3.eth.get_block(BlockNumber(block_number)) block_hashes.append(getattr(w3_get_block, "hash", None)) return block_hashes -async def async_local_filter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: - filters = {} - filter_id_counter = map(to_hex, itertools.count()) +# -- middleware -- # - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in NEW_FILTER_METHODS: - filter_id = next(filter_id_counter) +SyncFilter = Union[RequestLogs, RequestBlocks] +AsyncFilter = Union[AsyncRequestLogs, AsyncRequestBlocks] - _filter: Union[AsyncRequestLogs, AsyncRequestBlocks] - if method == RPC.eth_newFilter: - _filter = await AsyncRequestLogs( - w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) - ) - elif method == RPC.eth_newBlockFilter: - _filter = await AsyncRequestBlocks(w3) +class LocalFilterMiddleware(Web3Middleware): + def __init__(self, w3: Union["Web3", "AsyncWeb3"]): + self.filters: Dict[str, SyncFilter] = {} + self.async_filters: Dict[str, AsyncFilter] = {} + self.filter_id_counter = itertools.count() + super().__init__(w3) + def _wrap_make_request(self, make_request: MakeRequestFn) -> MakeRequestFn: + def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": + if method in NEW_FILTER_METHODS: + _filter: SyncFilter + filter_id = to_hex(next(self.filter_id_counter)) + + if method == RPC.eth_newFilter: + _filter = RequestLogs( + cast("Web3", self._w3), + **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + ) + + elif method == RPC.eth_newBlockFilter: + _filter = RequestBlocks(cast("Web3", self._w3)) + + else: + raise NotImplementedError(method) + + self.filters[filter_id] = _filter + return {"result": filter_id} + + elif method in FILTER_CHANGES_METHODS: + _filter_id = params[0] + + # Pass through to filters not created by middleware + if _filter_id not in self.filters: + return make_request(method, params) + + _filter = self.filters[_filter_id] + if method == RPC.eth_getFilterChanges: + return {"result": next(_filter.filter_changes)} + + elif method == RPC.eth_getFilterLogs: + # type ignored b/c logic prevents RequestBlocks which + # doesn't implement get_logs + return {"result": _filter.get_logs()} # type: ignore + else: + raise NotImplementedError(method) else: - raise NotImplementedError(method) + return make_request(method, params) - filters[filter_id] = _filter - return {"result": filter_id} + return middleware - elif method in FILTER_CHANGES_METHODS: - filter_id = params[0] - # Pass through to filters not created by middleware - if filter_id not in filters: - return await make_request(method, params) - _filter = filters[filter_id] + # -- async -- # - if method == RPC.eth_getFilterChanges: - changes = await _filter.filter_changes.__anext__() - return {"result": changes} + async def _async_wrap_make_request( + self, + make_request: AsyncMakeRequestFn, + ) -> AsyncMakeRequestFn: + async def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": + if method in NEW_FILTER_METHODS: + _filter: AsyncFilter + filter_id = to_hex(next(self.filter_id_counter)) + + if method == RPC.eth_newFilter: + _filter = await AsyncRequestLogs( + cast("AsyncWeb3", self._w3), + **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + ) + + elif method == RPC.eth_newBlockFilter: + _filter = await AsyncRequestBlocks(cast("AsyncWeb3", self._w3)) + + else: + raise NotImplementedError(method) + + self.async_filters[filter_id] = _filter + return {"result": filter_id} + + elif method in FILTER_CHANGES_METHODS: + _filter_id = params[0] + + # Pass through to filters not created by middleware + if _filter_id not in self.async_filters: + return await make_request(method, params) - elif method == RPC.eth_getFilterLogs: - # type ignored b/c logic prevents RequestBlocks which - # doesn't implement get_logs - logs = await _filter.get_logs() # type: ignore - return {"result": logs} + _filter = self.async_filters[_filter_id] + if method == RPC.eth_getFilterChanges: + return {"result": await _filter.filter_changes.__anext__()} + + elif method == RPC.eth_getFilterLogs: + # type ignored b/c logic prevents RequestBlocks which + # doesn't implement get_logs + return {"result": await _filter.get_logs()} # type: ignore + else: + raise NotImplementedError(method) else: - raise NotImplementedError(method) - else: - return await make_request(method, params) + return await make_request(method, params) + + return middleware + - return middleware +local_filter_middleware = LocalFilterMiddleware diff --git a/web3/middleware/fixture.py b/web3/middleware/fixture.py deleted file mode 100644 index 6c87ee8f9a..0000000000 --- a/web3/middleware/fixture.py +++ /dev/null @@ -1,190 +0,0 @@ -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - cast, -) - -from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, - Middleware, - RPCEndpoint, - RPCResponse, -) - -if TYPE_CHECKING: - from web3.main import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - from web3.providers import ( # noqa: F401 - PersistentConnectionProvider, - ) - - -def construct_fixture_middleware(fixtures: Dict[RPCEndpoint, Any]) -> Middleware: - """ - Constructs a middleware which returns a static response for any method - which is found in the provided fixtures. - """ - - def fixture_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in fixtures: - result = fixtures[method] - return {"result": result} - else: - return make_request(method, params) - - return middleware - - return fixture_middleware - - -def construct_result_generator_middleware( - result_generators: Dict[RPCEndpoint, Any] -) -> Middleware: - """ - Constructs a middleware which intercepts requests for any method found in - the provided mapping of endpoints to generator functions, returning - whatever response the generator function returns. Callbacks must be - functions with the signature `fn(method, params)`. - """ - - def result_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in result_generators: - result = result_generators[method](method, params) - return {"result": result} - else: - return make_request(method, params) - - return middleware - - return result_generator_middleware - - -def construct_error_generator_middleware( - error_generators: Dict[RPCEndpoint, Any] -) -> Middleware: - """ - Constructs a middleware which intercepts requests for any method found in - the provided mapping of endpoints to generator functions, returning - whatever error message the generator function returns. Callbacks must be - functions with the signature `fn(method, params)`. - """ - - def error_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in error_generators: - error = error_generators[method](method, params) - if isinstance(error, dict) and error.get("error", False): - return { - "error": { - "code": error.get("code", -32000), - "message": error["error"].get("message", ""), - "data": error.get("data", ""), - } - } - else: - return {"error": error} - else: - return make_request(method, params) - - return middleware - - return error_generator_middleware - - -# --- async --- # - - -async def async_construct_result_generator_middleware( - result_generators: Dict[RPCEndpoint, Any] -) -> AsyncMiddleware: - """ - Constructs a middleware which returns a static response for any method - which is found in the provided fixtures. - """ - - async def result_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if method in result_generators: - result = result_generators[method](method, params) - - if async_w3.provider.has_persistent_connection: - provider = cast("PersistentConnectionProvider", async_w3.provider) - response = await make_request(method, params) - provider._request_processor.append_middleware_response_processor( - # processed asynchronously later but need to pass the actual - # response to the next middleware - response, - lambda _: {"result": result}, - ) - return response - else: - return {"result": result} - else: - return await make_request(method, params) - - return middleware - - return result_generator_middleware - - -async def async_construct_error_generator_middleware( - error_generators: Dict[RPCEndpoint, Any] -) -> AsyncMiddleware: - """ - Constructs a middleware which intercepts requests for any method found in - the provided mapping of endpoints to generator functions, returning - whatever error message the generator function returns. Callbacks must be - functions with the signature `fn(method, params)`. - """ - - async def error_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if method in error_generators: - error = error_generators[method](method, params) - if isinstance(error, dict) and error.get("error", False): - error_response = { - "error": { - "code": error.get("code", -32000), - "message": error["error"].get("message", ""), - "data": error.get("data", ""), - } - } - else: - error_response = {"error": error} - - if async_w3.provider.has_persistent_connection: - provider = cast("PersistentConnectionProvider", async_w3.provider) - response = await make_request(method, params) - provider._request_processor.append_middleware_response_processor( - # processed asynchronously later but need to pass the actual - # response to the next middleware - response, - lambda _: error_response, - ) - return response - else: - return cast(RPCResponse, error_response) - else: - return await make_request(method, params) - - return middleware - - return error_generator_middleware diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 1fddb9cea4..f3b6dcd021 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -4,6 +4,7 @@ Callable, Coroutine, Optional, + Union, cast, ) @@ -13,14 +14,14 @@ merge, ) +from web3.middleware.base import ( + Web3MiddlewareBuilder, +) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, EthSubscriptionParams, Formatters, FormattersDict, Literal, - Middleware, RPCEndpoint, RPCResponse, ) @@ -30,6 +31,9 @@ AsyncWeb3, Web3, ) + from web3.middleware.base import ( # noqa: F401 + Web3Middleware, + ) from web3.providers import ( # noqa: F401 PersistentConnectionProvider, ) @@ -85,125 +89,137 @@ def _format_response( return response -# --- sync -- # - - -def construct_formatting_middleware( - request_formatters: Optional[Formatters] = None, - result_formatters: Optional[Formatters] = None, - error_formatters: Optional[Formatters] = None, -) -> Middleware: - def ignore_web3_in_standard_formatters( - _w3: "Web3", - _method: RPCEndpoint, - ) -> FormattersDict: - return dict( - request_formatters=request_formatters or {}, - result_formatters=result_formatters or {}, - error_formatters=error_formatters or {}, - ) +SYNC_FORMATTERS_BUILDER = Callable[["Web3", RPCEndpoint], FormattersDict] +ASYNC_FORMATTERS_BUILDER = Callable[ + ["AsyncWeb3", RPCEndpoint], Coroutine[Any, Any, FormattersDict] +] + + +class FormattingMiddlewareBuilder(Web3MiddlewareBuilder): + request_formatters: Formatters = None + result_formatters: Formatters = None + error_formatters: Formatters = None + sync_formatters_builder: SYNC_FORMATTERS_BUILDER = None + async_formatters_builder: ASYNC_FORMATTERS_BUILDER = None + + @staticmethod + @curry + def build( + w3: Union["AsyncWeb3", "Web3"], + # formatters option: + request_formatters: Optional[Formatters] = None, + result_formatters: Optional[Formatters] = None, + error_formatters: Optional[Formatters] = None, + # formatters builder option: + sync_formatters_builder: Optional[SYNC_FORMATTERS_BUILDER] = None, + async_formatters_builder: Optional[ASYNC_FORMATTERS_BUILDER] = None, + ) -> "FormattingMiddlewareBuilder": + # if not both sync and async formatters are specified, raise error + if ( + sync_formatters_builder is None and async_formatters_builder is not None + ) or (sync_formatters_builder is not None and async_formatters_builder is None): + raise ValueError( + "Must specify both sync_formatters_builder and async_formatters_builder" + ) - return construct_web3_formatting_middleware(ignore_web3_in_standard_formatters) + if sync_formatters_builder is not None and async_formatters_builder is not None: + if ( + request_formatters is not None + or result_formatters is not None + or error_formatters is not None + ): + raise ValueError( + "Cannot specify formatters_builder and formatters at the same time" + ) + middleware = FormattingMiddlewareBuilder(w3) + middleware.request_formatters = request_formatters or {} + middleware.result_formatters = result_formatters or {} + middleware.error_formatters = error_formatters or {} + middleware.sync_formatters_builder = sync_formatters_builder + middleware.async_formatters_builder = async_formatters_builder + return middleware -def construct_web3_formatting_middleware( - web3_formatters_builder: Callable[["Web3", RPCEndpoint], FormattersDict], -) -> Middleware: - def formatter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - w3: "Web3", - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if self.sync_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - web3_formatters_builder(w3, method), - ) - request_formatters = formatters.pop("request_formatters") - - if method in request_formatters: - formatter = request_formatters[method] - params = formatter(params) - response = make_request(method, params) - - return _apply_response_formatters( - method, - formatters["result_formatters"], - formatters["error_formatters"], - response, + self.sync_formatters_builder(cast("Web3", self._w3), method), ) + self.request_formatters = formatters.pop("request_formatters") - return middleware - - return formatter_middleware + if method in self.request_formatters: + formatter = self.request_formatters[method] + params = formatter(params) + return method, params -# --- async --- # + def response_processor(self, method: RPCEndpoint, response: "RPCResponse") -> Any: + if self.sync_formatters_builder is not None: + formatters = merge( + FORMATTER_DEFAULTS, + self.sync_formatters_builder(cast("Web3", self._w3), method), + ) + self.result_formatters = formatters["result_formatters"] + self.error_formatters = formatters["error_formatters"] + + return _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + response, + ) + # -- async -- # -async def async_construct_formatting_middleware( - request_formatters: Optional[Formatters] = None, - result_formatters: Optional[Formatters] = None, - error_formatters: Optional[Formatters] = None, -) -> AsyncMiddleware: - async def ignore_web3_in_standard_formatters( - _async_w3: "AsyncWeb3", - _method: RPCEndpoint, - ) -> FormattersDict: - return dict( - request_formatters=request_formatters or {}, - result_formatters=result_formatters or {}, - error_formatters=error_formatters or {}, - ) + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if self.async_formatters_builder is not None: + formatters = merge( + FORMATTER_DEFAULTS, + await self.async_formatters_builder( + cast("AsyncWeb3", self._w3), method + ), + ) + self.request_formatters = formatters.pop("request_formatters") - return await async_construct_web3_formatting_middleware( - ignore_web3_in_standard_formatters - ) + if method in self.request_formatters: + formatter = self.request_formatters[method] + params = formatter(params) + return method, params -async def async_construct_web3_formatting_middleware( - async_web3_formatters_builder: Callable[ - ["AsyncWeb3", RPCEndpoint], Coroutine[Any, Any, FormattersDict] - ] -) -> Callable[ - [Callable[[RPCEndpoint, Any], Any], "AsyncWeb3"], - Coroutine[Any, Any, AsyncMiddlewareCoroutine], -]: - async def formatter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - async_w3: "AsyncWeb3", - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: + async def async_response_processor( + self, method: RPCEndpoint, response: "RPCResponse" + ) -> Any: + if self.async_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - await async_web3_formatters_builder(async_w3, method), + await self.async_formatters_builder( + cast("AsyncWeb3", self._w3), method + ), ) - request_formatters = formatters.pop("request_formatters") - - if method in request_formatters: - formatter = request_formatters[method] - params = formatter(params) - response = await make_request(method, params) - - if async_w3.provider.has_persistent_connection: - # asynchronous response processing - provider = cast("PersistentConnectionProvider", async_w3.provider) - provider._request_processor.append_middleware_response_processor( - response, - _apply_response_formatters( - method, - formatters["result_formatters"], - formatters["error_formatters"], - ), - ) - return response - else: - return _apply_response_formatters( + self.result_formatters = formatters["result_formatters"] + self.error_formatters = formatters["error_formatters"] + + if self._w3.provider.has_persistent_connection: + # asynchronous response processing + provider = cast("PersistentConnectionProvider", self._w3.provider) + provider._request_processor.append_middleware_response_processor( + response, + _apply_response_formatters( method, - formatters["result_formatters"], - formatters["error_formatters"], - response, - ) + self.result_formatters, + self.error_formatters, + ), + ) + return response + else: + return _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + response, + ) - return middleware - return formatter_middleware +construct_formatting_middleware = FormattingMiddlewareBuilder.build diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index 1cd5a31f09..e259b5e054 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -1,7 +1,7 @@ from typing import ( TYPE_CHECKING, Any, - Callable, + cast, ) from eth_utils.toolz import ( @@ -23,11 +23,12 @@ InvalidTransaction, TransactionTypeMismatch, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddlewareCoroutine, BlockData, RPCEndpoint, - RPCResponse, TxParams, Wei, ) @@ -78,51 +79,41 @@ def validate_transaction_params( return transaction -def gas_price_strategy_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: +class GasPriceStrategyMiddleware(Web3Middleware): """ - - Uses a gas price strategy if one is set. This is only supported - for legacy transactions. It is recommended to send dynamic fee - transactions (EIP-1559) whenever possible. + - Uses a gas price strategy if one is set. This is only supported for + legacy transactions. It is recommended to send dynamic fee transactions + (EIP-1559) whenever possible. - Validates transaction params against legacy and dynamic fee txn values. """ - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + def request_processor(self, method: RPCEndpoint, params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] - generated_gas_price = w3.eth.generate_gas_price(transaction) + generated_gas_price = self._w3.eth.generate_gas_price(transaction) + w3 = cast("Web3", self._w3) latest_block = w3.eth.get_block("latest") transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) - return make_request(method, (transaction,)) - return make_request(method, params) + params = (transaction,) - return middleware + return method, params + # -- async -- # -async def async_gas_price_strategy_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - """ - - Uses a gas price strategy if one is set. This is only supported for - legacy transactions. It is recommended to send dynamic fee transactions - (EIP-1559) whenever possible. - - - Validates transaction params against legacy and dynamic fee txn values. - """ - - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + async def async_request_processor(self, method: RPCEndpoint, params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] - generated_gas_price = async_w3.eth.generate_gas_price(transaction) - latest_block = await async_w3.eth.get_block("latest") + w3 = cast("AsyncWeb3", self._w3) + generated_gas_price = w3.eth.generate_gas_price(transaction) + latest_block = await w3.eth.get_block("latest") transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) - return await make_request(method, (transaction,)) - return await make_request(method, params) + params = (transaction,) + return method, params + - return middleware +gas_price_strategy_middleware = GasPriceStrategyMiddleware diff --git a/web3/middleware/geth_poa.py b/web3/middleware/geth_poa.py deleted file mode 100644 index dd80431991..0000000000 --- a/web3/middleware/geth_poa.py +++ /dev/null @@ -1,81 +0,0 @@ -from typing import ( - TYPE_CHECKING, - Any, - Callable, -) - -from eth_utils import ( - is_dict, -) -from eth_utils.curried import ( - apply_formatter_if, - apply_formatters_to_dict, - apply_key_map, - is_null, -) -from eth_utils.toolz import ( - complement, - compose, -) -from hexbytes import ( - HexBytes, -) - -from web3._utils.rpc_abi import ( - RPC, -) -from web3.middleware.formatting import ( - async_construct_formatting_middleware, - construct_formatting_middleware, -) -from web3.types import ( - AsyncMiddlewareCoroutine, - RPCEndpoint, -) - -if TYPE_CHECKING: - from web3 import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - -is_not_null = complement(is_null) - -remap_geth_poa_fields = apply_key_map( - { - "extraData": "proofOfAuthorityData", - } -) - -pythonic_geth_poa = apply_formatters_to_dict( - { - "proofOfAuthorityData": HexBytes, - } -) - -geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) - - -geth_poa_middleware = construct_formatting_middleware( - result_formatters={ - RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), - RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), - }, -) - - -async def async_geth_poa_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - middleware = await async_construct_formatting_middleware( - result_formatters={ - RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), - RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), - RPC.eth_subscribe: apply_formatter_if( - is_not_null, - # original call to eth_subscribe returns a string, needs a dict check - apply_formatter_if(is_dict, geth_poa_cleanup), - ), - }, - ) - return await middleware(make_request, w3) diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 71a8cb4421..486f6fa9e2 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -1,10 +1,10 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Dict, Sequence, Union, + cast, ) from toolz import ( @@ -20,8 +20,6 @@ abi_request_formatters, ) from web3.types import ( - AsyncMiddlewareCoroutine, - Middleware, RPCEndpoint, ) @@ -33,8 +31,11 @@ from .._utils.formatters import ( recursive_map, ) +from .base import ( + Web3Middleware, +) from .formatting import ( - construct_formatting_middleware, + FormattingMiddlewareBuilder, ) if TYPE_CHECKING: @@ -44,18 +45,6 @@ ) -def name_to_address_middleware(w3: "Web3") -> Middleware: - normalizers = [ - abi_ens_resolver(w3), - ] - return construct_formatting_middleware( - request_formatters=abi_request_formatters(normalizers, RPC_ABIS) - ) - - -# -- async -- # - - def _is_logs_subscription_with_optional_args(method: RPCEndpoint, params: Any) -> bool: return method == "eth_subscribe" and len(params) == 2 and params[0] == "logs" @@ -109,11 +98,23 @@ async def async_apply_ens_to_address_conversion( ) -async def async_name_to_address_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - async_w3: "AsyncWeb3", -) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Any: +class EnsNameToAddressMiddleware(Web3Middleware): + _formatting_middleware = None + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if self._formatting_middleware is None: + normalizers = [ + abi_ens_resolver(self._w3), + ] + self._formatting_middleware = FormattingMiddlewareBuilder.build( + request_formatters=abi_request_formatters(normalizers, RPC_ABIS) + ) + + return self._formatting_middleware(self._w3).request_processor(method, params) + + # -- async -- # + + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: abi_types_for_method = RPC_ABIS.get(method, None) if abi_types_for_method is not None: @@ -121,7 +122,7 @@ async def middleware(method: RPCEndpoint, params: Any) -> Any: # eth_subscribe optional logs params are unique. # Handle them separately here. (formatted_dict,) = await async_apply_ens_to_address_conversion( - async_w3, + cast("AsyncWeb3", self._w3), (params[1],), { "address": "address", @@ -132,10 +133,12 @@ async def middleware(method: RPCEndpoint, params: Any) -> Any: else: params = await async_apply_ens_to_address_conversion( - async_w3, + cast("AsyncWeb3", self._w3), params, abi_types_for_method, ) - return await make_request(method, params) - return middleware + return method, params + + +ens_name_to_address_middleware = EnsNameToAddressMiddleware diff --git a/web3/middleware/normalize_request_parameters.py b/web3/middleware/normalize_request_parameters.py index 93926c99a4..a18041fbd9 100644 --- a/web3/middleware/normalize_request_parameters.py +++ b/web3/middleware/normalize_request_parameters.py @@ -3,9 +3,9 @@ ) from .formatting import ( - construct_formatting_middleware, + FormattingMiddlewareBuilder, ) -request_parameter_normalizer = construct_formatting_middleware( +request_parameter_normalizer = FormattingMiddlewareBuilder.build( request_formatters=METHOD_NORMALIZERS, ) diff --git a/web3/middleware/proof_of_authority.py b/web3/middleware/proof_of_authority.py new file mode 100644 index 0000000000..0bcf1ee42d --- /dev/null +++ b/web3/middleware/proof_of_authority.py @@ -0,0 +1,68 @@ +from typing import ( + TYPE_CHECKING, +) + +from eth_utils import ( + is_dict, +) +from eth_utils.curried import ( + apply_formatter_if, + apply_formatters_to_dict, + apply_key_map, + is_null, +) +from eth_utils.toolz import ( + complement, + compose, +) +from hexbytes import ( + HexBytes, +) + +from web3._utils.rpc_abi import ( + RPC, +) +from web3.middleware.formatting import ( + FormattingMiddlewareBuilder, +) + +if TYPE_CHECKING: + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + +is_not_null = complement(is_null) + +remap_extradata_to_poa_fields = apply_key_map( + { + "extraData": "proofOfAuthorityData", + } +) + +pythonic_extradata_to_poa = apply_formatters_to_dict( + { + "proofOfAuthorityData": HexBytes, + } +) + +extradata_to_poa_cleanup = compose( + pythonic_extradata_to_poa, remap_extradata_to_poa_fields +) + + +extradata_to_poa_middleware = FormattingMiddlewareBuilder.build( + result_formatters={ + RPC.eth_getBlockByHash: apply_formatter_if( + is_not_null, extradata_to_poa_cleanup + ), + RPC.eth_getBlockByNumber: apply_formatter_if( + is_not_null, extradata_to_poa_cleanup + ), + RPC.eth_subscribe: apply_formatter_if( + is_not_null, + # original call to eth_subscribe returns a string, needs a dict check + apply_formatter_if(is_dict, extradata_to_poa_cleanup), + ), + }, +) diff --git a/web3/middleware/pythonic.py b/web3/middleware/pythonic.py index 1ffd253e68..cb8856595b 100644 --- a/web3/middleware/pythonic.py +++ b/web3/middleware/pythonic.py @@ -3,10 +3,10 @@ PYTHONIC_RESULT_FORMATTERS, ) from web3.middleware.formatting import ( - construct_formatting_middleware, + FormattingMiddlewareBuilder, ) -pythonic_middleware = construct_formatting_middleware( +pythonic_middleware = FormattingMiddlewareBuilder.build( request_formatters=PYTHONIC_REQUEST_FORMATTERS, result_formatters=PYTHONIC_RESULT_FORMATTERS, ) diff --git a/web3/middleware/signing.py b/web3/middleware/signing.py index 0ccc629f21..2d35932e2b 100644 --- a/web3/middleware/signing.py +++ b/web3/middleware/signing.py @@ -5,12 +5,12 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Collection, Iterable, Tuple, TypeVar, Union, + cast, ) from eth_account import ( @@ -36,6 +36,9 @@ from eth_utils.toolz import ( compose, ) +from toolz import ( + curry, +) from web3._utils.async_transactions import ( async_fill_nonce, @@ -52,12 +55,11 @@ fill_nonce, fill_transaction_defaults, ) +from web3.middleware.base import ( + Web3MiddlewareBuilder, +) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, - Middleware, RPCEndpoint, - RPCResponse, TxParams, ) @@ -141,99 +143,76 @@ def format_transaction(transaction: TxParams) -> TxParams: ) -def construct_sign_and_send_raw_middleware( - private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]] -) -> Middleware: - """Capture transactions sign and send as raw transactions - - - Keyword arguments: - private_key_or_account -- A single private key or a tuple, - list or set of private keys. Keys can be any of the following formats: - - An eth_account.LocalAccount object - - An eth_keys.PrivateKey object - - A raw private key as a hex string or byte string - """ - - accounts = gen_normalized_accounts(private_key_or_account) - - def sign_and_send_raw_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - format_and_fill_tx = compose( - format_transaction, fill_transaction_defaults(w3), fill_nonce(w3) - ) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method != "eth_sendTransaction": - return make_request(method, params) - else: - transaction = format_and_fill_tx(params[0]) - - if "from" not in transaction: - return make_request(method, params) - elif transaction.get("from") not in accounts: - return make_request(method, params) - - account = accounts[transaction["from"]] - raw_tx = account.sign_transaction(transaction).rawTransaction - - return make_request(RPCEndpoint("eth_sendRawTransaction"), [raw_tx.hex()]) +class SignAndSendRawMiddlewareBuilder(Web3MiddlewareBuilder): + _accounts = None + format_and_fill_tx = None + @staticmethod + @curry + def build( + private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]], + w3: Union["Web3", "AsyncWeb3"], + ) -> "SignAndSendRawMiddlewareBuilder": + middleware = SignAndSendRawMiddlewareBuilder(w3) + middleware._accounts = gen_normalized_accounts(private_key_or_account) return middleware - return sign_and_send_raw_middleware + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method != "eth_sendTransaction": + return method, params + else: + w3 = cast("Web3", self._w3) + if self.format_and_fill_tx is None: + self.format_and_fill_tx = compose( + format_transaction, + fill_transaction_defaults(w3), + fill_nonce(w3), + ) + + filled_transaction = self.format_and_fill_tx(params[0]) + tx_from = filled_transaction.get("from", None) + if tx_from is None or ( + tx_from is not None and tx_from not in self._accounts + ): + return method, params + else: + account = self._accounts[to_checksum_address(tx_from)] + raw_tx = account.sign_transaction(filled_transaction).rawTransaction -# -- async -- # + return ( + RPCEndpoint("eth_sendRawTransaction"), + [raw_tx.hex()], + ) + # -- async -- # -async def async_construct_sign_and_send_raw_middleware( - private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]] -) -> AsyncMiddleware: - """ - Capture transactions & sign and send as raw transactions - - Keyword arguments: - private_key_or_account -- A single private key or a tuple, - list or set of private keys. Keys can be any of the following formats: - - An eth_account.LocalAccount object - - An eth_keys.PrivateKey object - - A raw private key as a hex string or byte string - """ - accounts = gen_normalized_accounts(private_key_or_account) + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method != "eth_sendTransaction": + return method, params - async def async_sign_and_send_raw_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method != "eth_sendTransaction": - # quick exit if not `eth_sendTransaction` - return await make_request(method, params) + else: + w3 = cast("AsyncWeb3", self._w3) formatted_transaction = format_transaction(params[0]) filled_transaction = await async_fill_transaction_defaults( - async_w3, - formatted_transaction, - ) - filled_transaction = await async_fill_nonce( - async_w3, - filled_transaction, + w3, formatted_transaction ) - + filled_transaction = await async_fill_nonce(w3, filled_transaction) tx_from = filled_transaction.get("from", None) - if tx_from is None or (tx_from is not None and tx_from not in accounts): - return await make_request(method, params) + if tx_from is None or ( + tx_from is not None and tx_from not in self._accounts + ): + return method, params + else: + account = self._accounts[to_checksum_address(tx_from)] + raw_tx = account.sign_transaction(filled_transaction).rawTransaction - account = accounts[to_checksum_address(tx_from)] - raw_tx = account.sign_transaction(filled_transaction).rawTransaction + return ( + RPCEndpoint("eth_sendRawTransaction"), + [raw_tx.hex()], + ) - return await make_request( - RPCEndpoint("eth_sendRawTransaction"), - [raw_tx.hex()], - ) - - return middleware - return async_sign_and_send_raw_middleware +construct_sign_and_send_raw_middleware = SignAndSendRawMiddlewareBuilder.build diff --git a/web3/middleware/simulate_unmined_transaction.py b/web3/middleware/simulate_unmined_transaction.py deleted file mode 100644 index 18e2f8bf2c..0000000000 --- a/web3/middleware/simulate_unmined_transaction.py +++ /dev/null @@ -1,43 +0,0 @@ -import collections -import itertools -from typing import ( - Any, - Callable, -) - -from eth_typing import ( - Hash32, -) - -from web3 import ( - Web3, -) -from web3.types import ( - RPCEndpoint, - RPCResponse, - TxReceipt, -) - -counter = itertools.count() - -INVOCATIONS_BEFORE_RESULT = 5 - - -def unmined_receipt_simulator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: Web3 -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - receipt_counters: DefaultDict[Hash32, TxReceipt] = collections.defaultdict( # type: ignore # noqa: F821, E501 - itertools.count - ) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method == "eth_getTransactionReceipt": - txn_hash = params[0] - if next(receipt_counters[txn_hash]) < INVOCATIONS_BEFORE_RESULT: - return {"result": None} - else: - return make_request(method, params) - else: - return make_request(method, params) - - return middleware diff --git a/web3/middleware/stalecheck.py b/web3/middleware/stalecheck.py index 16b98715a3..e5eaedefa1 100644 --- a/web3/middleware/stalecheck.py +++ b/web3/middleware/stalecheck.py @@ -6,18 +6,24 @@ Collection, Dict, Optional, + Union, + cast, +) + +from toolz import ( + curry, ) from web3.exceptions import ( StaleBlockchain, ) +from web3.middleware.base import ( + Web3Middleware, + Web3MiddlewareBuilder, +) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, BlockData, - Middleware, RPCEndpoint, - RPCResponse, ) if TYPE_CHECKING: @@ -35,86 +41,55 @@ def _is_fresh(block: BlockData, allowable_delay: int) -> bool: return False -def make_stalecheck_middleware( - allowable_delay: int, - skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, -) -> Middleware: - """ - Use to require that a function will run only of the blockchain is recently updated. - - This middleware takes an argument, so unlike other middleware, you must make the - middleware with a method call. - - For example: `make_stalecheck_middleware(60*5)` - - If the latest block in the chain is older than 5 minutes in this example, then the - middleware will raise a StaleBlockchain exception. - """ - if allowable_delay <= 0: - raise ValueError( - "You must set a positive allowable_delay in seconds for this middleware" - ) - - def stalecheck_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - cache: Dict[str, Optional[BlockData]] = {"latest": None} - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method not in skip_stalecheck_for_methods: - if not _is_fresh(cache["latest"], allowable_delay): - latest = w3.eth.get_block("latest") - if _is_fresh(latest, allowable_delay): - cache["latest"] = latest - else: - raise StaleBlockchain(latest, allowable_delay) - - return make_request(method, params) - +class StaleCheckMiddlewareBuilder(Web3MiddlewareBuilder): + allowable_delay: int + skip_stalecheck_for_methods: Collection[str] + cache: Dict[str, Optional[BlockData]] + + @staticmethod + @curry + def build( + allowable_delay: int, + w3: Union["Web3", "AsyncWeb3"], + skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, + ) -> Web3Middleware: + if allowable_delay <= 0: + raise ValueError( + "You must set a positive allowable_delay in seconds for this middleware" + ) + middleware = StaleCheckMiddlewareBuilder(w3) + middleware.allowable_delay = allowable_delay + middleware.skip_stalecheck_for_methods = skip_stalecheck_for_methods + middleware.cache = {"latest": None} return middleware - return stalecheck_middleware - - -# -- async -- # + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method not in self.skip_stalecheck_for_methods: + if not _is_fresh(self.cache["latest"], self.allowable_delay): + w3 = cast("Web3", self._w3) + latest = w3.eth.get_block("latest") + if _is_fresh(latest, self.allowable_delay): + self.cache["latest"] = latest + else: + raise StaleBlockchain(latest, self.allowable_delay) -async def async_make_stalecheck_middleware( - allowable_delay: int, - skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, -) -> AsyncMiddleware: - """ - Use to require that a function will run only of the blockchain is recently updated. + return method, params - This middleware takes an argument, so unlike other middleware, you must make the - middleware with a method call. + # -- async -- # - For example: `async_make_stalecheck_middleware(60*5)` + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method not in self.skip_stalecheck_for_methods: + if not _is_fresh(self.cache["latest"], self.allowable_delay): + w3 = cast("AsyncWeb3", self._w3) + latest = await w3.eth.get_block("latest") - If the latest block in the chain is older than 5 minutes in this example, then the - middleware will raise a StaleBlockchain exception. - """ - if allowable_delay <= 0: - raise ValueError( - "You must set a positive allowable_delay in seconds for this middleware" - ) + if _is_fresh(latest, self.allowable_delay): + self.cache["latest"] = latest + else: + raise StaleBlockchain(latest, self.allowable_delay) - async def stalecheck_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - cache: Dict[str, Optional[BlockData]] = {"latest": None} + return method, params - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method not in skip_stalecheck_for_methods: - if not _is_fresh(cache["latest"], allowable_delay): - latest = await w3.eth.get_block("latest") - if _is_fresh(latest, allowable_delay): - cache["latest"] = latest - else: - raise StaleBlockchain(latest, allowable_delay) - - return await make_request(method, params) - - return middleware - return stalecheck_middleware +make_stalecheck_middleware = StaleCheckMiddlewareBuilder.build diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 9686a86715..a8660e5e8f 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -33,11 +33,9 @@ Web3ValidationError, ) from web3.middleware.formatting import ( - async_construct_web3_formatting_middleware, - construct_web3_formatting_middleware, + FormattingMiddlewareBuilder, ) from web3.types import ( - AsyncMiddlewareCoroutine, Formatters, FormattersDict, RPCEndpoint, @@ -147,9 +145,6 @@ def build_method_validators(w3: "Web3", method: RPCEndpoint) -> FormattersDict: return _build_formatters_dict(request_formatters) -validation_middleware = construct_web3_formatting_middleware(build_method_validators) - - # -- async --- # @@ -165,10 +160,7 @@ async def async_build_method_validators( return _build_formatters_dict(request_formatters) -async def async_validation_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - middleware = await async_construct_web3_formatting_middleware( - async_build_method_validators - ) - return await middleware(make_request, w3) +validation_middleware = FormattingMiddlewareBuilder.build( + sync_formatters_builder=build_method_validators, + async_formatters_builder=async_build_method_validators, +) diff --git a/web3/module.py b/web3/module.py index 41ad2ca156..e04cf5ea31 100644 --- a/web3/module.py +++ b/web3/module.py @@ -103,7 +103,11 @@ async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter method_str = cast(RPCEndpoint, method_str) return await async_w3.manager.ws_send(method_str, params) except Exception as e: - if cache_key in provider._request_processor._request_information_cache: + if ( + cache_key is not None + and cache_key + in provider._request_processor._request_information_cache + ): provider._request_processor.pop_cached_request_information( cache_key ) diff --git a/web3/providers/__init__.py b/web3/providers/__init__.py index 883b88060e..e3a74a4d93 100644 --- a/web3/providers/__init__.py +++ b/web3/providers/__init__.py @@ -1,7 +1,7 @@ from .async_base import ( AsyncBaseProvider, ) -from .async_rpc import ( +from .rpc import ( AsyncHTTPProvider, ) from .base import ( diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 30404b6b7c..471e44e66e 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -1,10 +1,11 @@ +import asyncio import itertools from typing import ( TYPE_CHECKING, Any, Callable, Coroutine, - Sequence, + Set, Tuple, cast, ) @@ -15,6 +16,9 @@ to_text, ) +from web3._utils.caching import ( + async_handle_request_caching, +) from web3._utils.encoding import ( FriendlyJsonSerde, Web3JsonEncoder, @@ -25,13 +29,17 @@ from web3.middleware import ( async_combine_middlewares, ) -from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareOnion, +from web3.middleware.base import ( + Middleware, MiddlewareOnion, +) +from web3.types import ( RPCEndpoint, RPCResponse, ) +from web3.utils import ( + SimpleCache, +) if TYPE_CHECKING: from web3 import ( # noqa: F401 @@ -40,53 +48,60 @@ ) +CACHEABLE_REQUESTS = cast( + Set[RPCEndpoint], + ( + "eth_chainId", + "eth_getBlockByHash", + "eth_getBlockTransactionCountByHash", + "eth_getRawTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getTransactionByHash", + "eth_getUncleByBlockHashAndIndex", + "eth_getUncleCountByBlockHash", + "net_version", + "web3_clientVersion", + ), +) + + class AsyncBaseProvider: - _middlewares: Tuple[AsyncMiddleware, ...] = () - # a tuple of (all_middlewares, request_func) _request_func_cache: Tuple[ - Tuple[AsyncMiddleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] - ] = ( - None, - None, - ) + Tuple[Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] + ] = (None, None) is_async = True has_persistent_connection = False global_ccip_read_enabled: bool = True ccip_read_max_redirects: int = 4 - @property - def middlewares(self) -> Tuple[AsyncMiddleware, ...]: - return self._middlewares + # request caching + cache_allowed_requests: bool = False + cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS + _request_cache: SimpleCache + _request_cache_lock: asyncio.Lock = asyncio.Lock() - @middlewares.setter - def middlewares(self, values: MiddlewareOnion) -> None: - # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] - self._middlewares = tuple(values) # type: ignore + def __init__(self) -> None: + self._request_cache = SimpleCache(1000) async def request_func( - self, async_w3: "AsyncWeb3", outer_middlewares: AsyncMiddlewareOnion + self, async_w3: "AsyncWeb3", middleware_onion: MiddlewareOnion ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - all_middlewares: Tuple[AsyncMiddleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 + middlewares: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middlewares() cache_key = self._request_func_cache[0] - if cache_key is None or cache_key != all_middlewares: + if cache_key != middlewares: self._request_func_cache = ( - all_middlewares, - await self._generate_request_func(async_w3, all_middlewares), + middlewares, + await async_combine_middlewares( + middlewares=middlewares, + async_w3=async_w3, + provider_request_fn=self.make_request, + ), ) return self._request_func_cache[-1] - async def _generate_request_func( - self, async_w3: "AsyncWeb3", middlewares: Sequence[AsyncMiddleware] - ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - return await async_combine_middlewares( - middlewares=middlewares, - async_w3=async_w3, - provider_request_fn=self.make_request, - ) - + @async_handle_request_caching async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") diff --git a/web3/providers/base.py b/web3/providers/base.py index d7877546cd..ba5d4c5962 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -1,9 +1,10 @@ import itertools +import threading from typing import ( TYPE_CHECKING, Any, Callable, - Sequence, + Set, Tuple, cast, ) @@ -13,6 +14,9 @@ to_text, ) +from web3._utils.caching import ( + handle_request_caching, +) from web3._utils.encoding import ( FriendlyJsonSerde, Web3JsonEncoder, @@ -23,20 +27,41 @@ from web3.middleware import ( combine_middlewares, ) -from web3.types import ( +from web3.middleware.base import ( Middleware, MiddlewareOnion, +) +from web3.types import ( RPCEndpoint, RPCResponse, ) +from web3.utils import ( + SimpleCache, +) if TYPE_CHECKING: from web3 import Web3 # noqa: F401 +CACHEABLE_REQUESTS = cast( + Set[RPCEndpoint], + ( + "eth_chainId", + "eth_getBlockByHash", + "eth_getBlockTransactionCountByHash", + "eth_getRawTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getTransactionByHash", + "eth_getUncleByBlockHashAndIndex", + "eth_getUncleCountByBlockHash", + "net_version", + "web3_clientVersion", + ), +) + + class BaseProvider: - _middlewares: Tuple[Middleware, ...] = () - # a tuple of (all_middlewares, request_func) + # a tuple of (middlewares, request_func) _request_func_cache: Tuple[Tuple[Middleware, ...], Callable[..., RPCResponse]] = ( None, None, @@ -47,44 +72,41 @@ class BaseProvider: global_ccip_read_enabled: bool = True ccip_read_max_redirects: int = 4 - @property - def middlewares(self) -> Tuple[Middleware, ...]: - return self._middlewares + # request caching + cache_allowed_requests: bool = False + cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS + _request_cache: SimpleCache + _request_cache_lock: threading.Lock = threading.Lock() - @middlewares.setter - def middlewares(self, values: MiddlewareOnion) -> None: - # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] - self._middlewares = tuple(values) # type: ignore + def __init__(self) -> None: + self._request_cache = SimpleCache(1000) def request_func( - self, w3: "Web3", outer_middlewares: MiddlewareOnion + self, w3: "Web3", middleware_onion: MiddlewareOnion ) -> Callable[..., RPCResponse]: """ - @param outer_middlewares is an iterable of middlewares, + @param w3 is the web3 instance + @param middleware_onion is an iterable of middlewares, ordered by first to execute @returns a function that calls all the middleware and eventually self.make_request() """ - # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - all_middlewares: Tuple[Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 + middlewares: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middlewares() cache_key = self._request_func_cache[0] - if cache_key is None or cache_key != all_middlewares: + if cache_key != middlewares: self._request_func_cache = ( - all_middlewares, - self._generate_request_func(w3, all_middlewares), + middlewares, + combine_middlewares( + middlewares=middlewares, + w3=w3, + provider_request_fn=self.make_request, + ), ) - return self._request_func_cache[-1] - def _generate_request_func( - self, w3: "Web3", middlewares: Sequence[Middleware] - ) -> Callable[..., RPCResponse]: - return combine_middlewares( - middlewares=middlewares, - w3=w3, - provider_request_fn=self.make_request, - ) + return self._request_func_cache[-1] + @handle_request_caching def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") @@ -95,6 +117,7 @@ def is_connected(self, show_traceback: bool = False) -> bool: class JSONBaseProvider(BaseProvider): def __init__(self) -> None: self.request_counter = itertools.count() + super().__init__() def decode_rpc_response(self, raw_response: bytes) -> RPCResponse: text_response = to_text(raw_response) diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index 069aee3603..70cb3f7104 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -2,6 +2,7 @@ TYPE_CHECKING, Any, Callable, + Coroutine, Dict, Optional, Union, @@ -20,13 +21,6 @@ from web3._utils.compat import ( Literal, ) -from web3.middleware.attrdict import ( - async_attrdict_middleware, - attrdict_middleware, -) -from web3.middleware.buffered_gas_estimate import ( - async_buffered_gas_estimate_middleware, -) from web3.providers import ( BaseProvider, ) @@ -39,9 +33,11 @@ RPCResponse, ) +from ...middleware import ( + async_combine_middlewares, + combine_middlewares, +) from .middleware import ( - async_default_transaction_fields_middleware, - async_ethereum_tester_middleware, default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -50,13 +46,21 @@ from eth_tester import EthereumTester # noqa: F401 from eth_tester.backends.base import BaseChainBackend # noqa: F401 + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3.middleware.base import ( # noqa: F401 + Middleware, + MiddlewareOnion, + Web3Middleware, + ) + class AsyncEthereumTesterProvider(AsyncBaseProvider): - middlewares = ( - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_default_transaction_fields_middleware, - async_ethereum_tester_middleware, + _middlewares = ( + default_transaction_fields_middleware, + ethereum_tester_middleware, ) def __init__(self) -> None: @@ -74,6 +78,27 @@ def __init__(self) -> None: self.ethereum_tester = EthereumTester() self.api_endpoints = API_ENDPOINTS + async def request_func( + self, async_w3: "AsyncWeb3", middleware_onion: "MiddlewareOnion" + ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: + # override the request_func to add the ethereum_tester_middleware + + middlewares = middleware_onion.as_tuple_of_middlewares() + tuple( + self._middlewares + ) + + cache_key = self._request_func_cache[0] + if cache_key != middlewares: + self._request_func_cache = ( + middlewares, + await async_combine_middlewares( + middlewares=middlewares, + async_w3=async_w3, + provider_request_fn=self.make_request, + ), + ) + return self._request_func_cache[-1] + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: return _make_request(method, params, self.api_endpoints, self.ethereum_tester) @@ -82,8 +107,7 @@ async def is_connected(self, show_traceback: bool = False) -> Literal[True]: class EthereumTesterProvider(BaseProvider): - middlewares = ( - attrdict_middleware, + _middlewares = ( default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -98,6 +122,7 @@ def __init__( ] = None, ) -> None: # do not import eth_tester until runtime, it is not a default dependency + super().__init__() from eth_tester import EthereumTester # noqa: F811 from eth_tester.backends.base import ( BaseChainBackend, @@ -129,6 +154,27 @@ def __init__( else: self.api_endpoints = api_endpoints + def request_func( + self, w3: "Web3", middleware_onion: "MiddlewareOnion" + ) -> Callable[..., RPCResponse]: + # override the request_func to add the ethereum_tester_middleware + + middlewares = middleware_onion.as_tuple_of_middlewares() + tuple( + self._middlewares + ) + + cache_key = self._request_func_cache[0] + if cache_key != middlewares: + self._request_func_cache = ( + middlewares, + combine_middlewares( + middlewares=middlewares, + w3=w3, + provider_request_fn=self.make_request, + ), + ) + return self._request_func_cache[-1] + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: return _make_request(method, params, self.api_endpoints, self.ethereum_tester) diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index a50e949b2a..bb8c8c3576 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -40,17 +40,14 @@ from web3._utils.method_formatters import ( apply_list_to_array_formatter, ) -from web3.middleware import ( - construct_formatting_middleware, +from web3.middleware.base import ( + Web3Middleware, ) from web3.middleware.formatting import ( - async_construct_formatting_middleware, + FormattingMiddlewareBuilder, ) from web3.types import ( - AsyncMiddlewareCoroutine, - Middleware, RPCEndpoint, - RPCResponse, TxParams, ) @@ -192,7 +189,6 @@ def is_hexstr(value: Any) -> bool: block_result_remapper = apply_key_map(BLOCK_RESULT_KEY_MAPPING) BLOCK_RESULT_FORMATTERS = { - "logsBloom": integer_to_hex, "withdrawals": apply_list_to_array_formatter( apply_key_map({"validator_index": "validatorIndex"}), ), @@ -273,12 +269,10 @@ def is_hexstr(value: Any) -> bool: } result_formatters: Optional[Dict[RPCEndpoint, Callable[..., Any]]] = { RPCEndpoint("eth_getBlockByHash"): apply_formatter_if( - is_dict, - compose(block_result_remapper, block_result_formatter), + is_dict, compose(block_result_remapper, block_result_formatter) ), RPCEndpoint("eth_getBlockByNumber"): apply_formatter_if( - is_dict, - compose(block_result_remapper, block_result_formatter), + is_dict, compose(block_result_remapper, block_result_formatter) ), RPCEndpoint("eth_getBlockTransactionCountByHash"): apply_formatter_if( is_dict, @@ -316,11 +310,6 @@ def is_hexstr(value: Any) -> bool: } -ethereum_tester_middleware = construct_formatting_middleware( - request_formatters=request_formatters, result_formatters=result_formatters -) - - def guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress: if w3.eth.coinbase: return w3.eth.coinbase @@ -342,40 +331,9 @@ def fill_default( return assoc(transaction, field, guess_val) -def default_transaction_fields_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in ( - "eth_call", - "eth_estimateGas", - "eth_sendTransaction", - "eth_createAccessList", - ): - fill_default_from = fill_default("from", guess_from, w3) - filled_transaction = pipe( - params[0], - fill_default_from, - ) - return make_request(method, [filled_transaction] + list(params)[1:]) - else: - return make_request(method, params) - - return middleware - - # --- async --- # -async def async_ethereum_tester_middleware( # type: ignore - make_request, web3: "AsyncWeb3" -) -> Middleware: - middleware = await async_construct_formatting_middleware( - request_formatters=request_formatters, result_formatters=result_formatters - ) - return await middleware(make_request, web3) - - async def async_guess_from( async_w3: "AsyncWeb3", _: TxParams ) -> Optional[ChecksumAddress]: @@ -403,20 +361,43 @@ async def async_fill_default( return assoc(transaction, field, guess_val) -async def async_default_transaction_fields_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: +# --- define middleware --- # + + +class DefaultTransactionFieldsMiddleware(Web3Middleware): + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method in ( + "eth_call", + "eth_estimateGas", + "eth_sendTransaction", + "eth_createAccessList", + ): + fill_default_from = fill_default("from", guess_from, self._w3) + filled_transaction = pipe( + params[0], + fill_default_from, + ) + params = [filled_transaction] + list(params)[1:] + return method, params + + # --- async --- # + + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method in ( "eth_call", "eth_estimateGas", "eth_sendTransaction", + "eth_createAccessList", ): filled_transaction = await async_fill_default( - "from", async_guess_from, async_w3, params[0] + "from", async_guess_from, self._w3, params[0] ) - return await make_request(method, [filled_transaction] + list(params)[1:]) - else: - return await make_request(method, params) + params = [filled_transaction] + list(params)[1:] + + return method, params + - return middleware +ethereum_tester_middleware = FormattingMiddlewareBuilder.build( + request_formatters=request_formatters, result_formatters=result_formatters +) +default_transaction_fields_middleware = DefaultTransactionFieldsMiddleware diff --git a/web3/providers/rpc/__init__.py b/web3/providers/rpc/__init__.py new file mode 100644 index 0000000000..cbac394616 --- /dev/null +++ b/web3/providers/rpc/__init__.py @@ -0,0 +1,6 @@ +from .async_rpc import ( + AsyncHTTPProvider, +) +from .rpc import ( + HTTPProvider, +) diff --git a/web3/providers/async_rpc.py b/web3/providers/rpc/async_rpc.py similarity index 55% rename from web3/providers/async_rpc.py rename to web3/providers/rpc/async_rpc.py index e5ab8cb79d..1fcd6065ef 100644 --- a/web3/providers/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import ( Any, @@ -27,33 +28,32 @@ get_default_http_endpoint, ) from web3.types import ( - AsyncMiddleware, RPCEndpoint, RPCResponse, ) -from ..datastructures import ( - NamedElementOnion, +from ..._utils.caching import ( + async_handle_request_caching, ) -from ..middleware.exception_retry_request import ( - async_http_retry_request_middleware, -) -from .async_base import ( +from ..async_base import ( AsyncJSONBaseProvider, ) +from .utils import ( + ExceptionRetryConfiguration, + check_if_retry_on_failure, +) class AsyncHTTPProvider(AsyncJSONBaseProvider): logger = logging.getLogger("web3.providers.AsyncHTTPProvider") endpoint_uri = None _request_kwargs = None - # type ignored b/c conflict with _middlewares attr on AsyncBaseProvider - _middlewares: Tuple[AsyncMiddleware, ...] = NamedElementOnion([(async_http_retry_request_middleware, "http_retry_request")]) # type: ignore # noqa: E501 def __init__( self, endpoint_uri: Optional[Union[URI, str]] = None, request_kwargs: Optional[Any] = None, + exception_retry_configuration: Optional[ExceptionRetryConfiguration] = None, ) -> None: if endpoint_uri is None: self.endpoint_uri = get_default_http_endpoint() @@ -62,6 +62,12 @@ def __init__( self._request_kwargs = request_kwargs or {} + self.exception_retry_configuration = ( + exception_retry_configuration + # use default values if not provided + or ExceptionRetryConfiguration() + ) + super().__init__() async def cache_async_session(self, session: ClientSession) -> ClientSession: @@ -83,14 +89,43 @@ def get_request_headers(self) -> Dict[str, str]: "User-Agent": construct_user_agent(str(type(self))), } + async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: + """ + If exception_retry_configuration is set, retry on failure; otherwise, make + the request without retrying. + """ + if ( + self.exception_retry_configuration is not None + and check_if_retry_on_failure( + method, self.exception_retry_configuration.method_allowlist + ) + ): + for i in range(self.exception_retry_configuration.retries): + try: + return await async_make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + except tuple(self.exception_retry_configuration.errors): + if i < self.exception_retry_configuration.retries - 1: + await asyncio.sleep( + self.exception_retry_configuration.backoff_factor + ) + continue + else: + raise + return None + else: + return await async_make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + + @async_handle_request_caching async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug( f"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}" ) request_data = self.encode_rpc_request(method, params) - raw_response = await async_make_post_request( - self.endpoint_uri, request_data, **self.get_request_kwargs() - ) + raw_response = await self._make_request(method, request_data) response = self.decode_rpc_response(raw_response) self.logger.debug( f"Getting response HTTP. URI: {self.endpoint_uri}, " diff --git a/web3/providers/rpc.py b/web3/providers/rpc/rpc.py similarity index 54% rename from web3/providers/rpc.py rename to web3/providers/rpc/rpc.py index bffa69392a..c2c9139de9 100644 --- a/web3/providers/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -1,5 +1,7 @@ import logging +import time from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, @@ -23,36 +25,45 @@ get_default_http_endpoint, make_post_request, ) -from web3.datastructures import ( - NamedElementOnion, -) -from web3.middleware import ( - http_retry_request_middleware, -) from web3.types import ( - Middleware, RPCEndpoint, RPCResponse, ) -from .base import ( +from ..._utils.caching import ( + handle_request_caching, +) +from ..base import ( JSONBaseProvider, ) +from .utils import ( + ExceptionRetryConfiguration, + check_if_retry_on_failure, +) + +if TYPE_CHECKING: + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) class HTTPProvider(JSONBaseProvider): logger = logging.getLogger("web3.providers.HTTPProvider") endpoint_uri = None + _request_args = None _request_kwargs = None - # type ignored b/c conflict with _middlewares attr on BaseProvider - _middlewares: Tuple[Middleware, ...] = NamedElementOnion([(http_retry_request_middleware, "http_retry_request")]) # type: ignore # noqa: E501 + + exception_retry_configuration: Optional[ExceptionRetryConfiguration] = None def __init__( self, endpoint_uri: Optional[Union[URI, str]] = None, request_kwargs: Optional[Any] = None, session: Optional[Any] = None, + exception_retry_configuration: Optional[ + ExceptionRetryConfiguration + ] = ExceptionRetryConfiguration(), ) -> None: if endpoint_uri is None: self.endpoint_uri = get_default_http_endpoint() @@ -60,6 +71,7 @@ def __init__( self.endpoint_uri = URI(endpoint_uri) self._request_kwargs = request_kwargs or {} + self.exception_retry_configuration = exception_retry_configuration if session: cache_and_return_session(self.endpoint_uri, session) @@ -82,14 +94,41 @@ def get_request_headers(self) -> Dict[str, str]: "User-Agent": construct_user_agent(str(type(self))), } + def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: + """ + If exception_retry_configuration is set, retry on failure; otherwise, make + the request without retrying. + """ + if ( + self.exception_retry_configuration is not None + and check_if_retry_on_failure( + method, self.exception_retry_configuration.method_allowlist + ) + ): + for i in range(self.exception_retry_configuration.retries): + try: + return make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + except tuple(self.exception_retry_configuration.errors) as e: + if i < self.exception_retry_configuration.retries - 1: + time.sleep(self.exception_retry_configuration.backoff_factor) + continue + else: + raise e + return None + else: + return make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + + @handle_request_caching def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug( f"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}" ) request_data = self.encode_rpc_request(method, params) - raw_response = make_post_request( - self.endpoint_uri, request_data, **self.get_request_kwargs() - ) + raw_response = self._make_request(method, request_data) response = self.decode_rpc_response(raw_response) self.logger.debug( f"Getting response HTTP. URI: {self.endpoint_uri}, " diff --git a/web3/providers/rpc/utils.py b/web3/providers/rpc/utils.py new file mode 100644 index 0000000000..e20fbc44f8 --- /dev/null +++ b/web3/providers/rpc/utils.py @@ -0,0 +1,110 @@ +from typing import ( + Sequence, + Type, +) + +from pydantic import ( + BaseModel, +) +import requests + +from web3.types import ( + RPCEndpoint, +) + +REQUEST_RETRY_ALLOWLIST = [ + "admin", + "miner", + "net", + "txpool", + "testing", + "evm", + "eth_protocolVersion", + "eth_syncing", + "eth_coinbase", + "eth_mining", + "eth_hashrate", + "eth_chainId", + "eth_gasPrice", + "eth_accounts", + "eth_blockNumber", + "eth_getBalance", + "eth_getStorageAt", + "eth_getProof", + "eth_getCode", + "eth_getBlockByNumber", + "eth_getBlockByHash", + "eth_getBlockTransactionCountByNumber", + "eth_getBlockTransactionCountByHash", + "eth_getUncleCountByBlockNumber", + "eth_getUncleCountByBlockHash", + "eth_getTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getTransactionByBlockNumberAndIndex", + "eth_getTransactionReceipt", + "eth_getTransactionCount", + "eth_getRawTransactionByHash", + "eth_call", + "eth_estimateGas", + "eth_createAccessList", + "eth_maxPriorityFeePerGas", + "eth_newBlockFilter", + "eth_newPendingTransactionFilter", + "eth_newFilter", + "eth_getFilterChanges", + "eth_getFilterLogs", + "eth_getLogs", + "eth_uninstallFilter", + "eth_getCompilers", + "eth_getWork", + "eth_sign", + "eth_signTypedData", + "eth_sendRawTransaction", + "personal_importRawKey", + "personal_newAccount", + "personal_listAccounts", + "personal_listWallets", + "personal_lockAccount", + "personal_unlockAccount", + "personal_ecRecover", + "personal_sign", + "personal_signTypedData", +] + + +def check_if_retry_on_failure( + method: RPCEndpoint, + allowlist: Sequence[str] = None, +) -> bool: + if allowlist is None: + allowlist = REQUEST_RETRY_ALLOWLIST + + if method in allowlist or method.split("_")[0] in allowlist: + return True + else: + return False + + +class ExceptionRetryConfiguration(BaseModel): + errors: Sequence[Type[BaseException]] + retries: int + backoff_factor: float + method_allowlist: Sequence[str] + + def __init__( + self, + errors: Sequence[Type[BaseException]] = ( + ConnectionError, + requests.HTTPError, + requests.Timeout, + ), + retries: int = 5, + backoff_factor: float = 0.5, + method_allowlist: Sequence[str] = None, + ): + super().__init__( + errors=errors, + retries=retries, + backoff_factor=backoff_factor, + method_allowlist=method_allowlist or REQUEST_RETRY_ALLOWLIST, + ) diff --git a/web3/providers/websocket/request_processor.py b/web3/providers/websocket/request_processor.py index e331e41012..1b938ab597 100644 --- a/web3/providers/websocket/request_processor.py +++ b/web3/providers/websocket/request_processor.py @@ -65,7 +65,19 @@ def cache_request_information( method: RPCEndpoint, params: Any, response_formatters: Tuple[Callable[..., Any], ...], - ) -> str: + ) -> Optional[str]: + cached_requests_key = generate_cache_key((method, params)) + if cached_requests_key in self._provider._request_cache._data: + cached_response = self._provider._request_cache._data[cached_requests_key] + cached_response_id = cached_response.get("id") + cache_key = generate_cache_key(cached_response_id) + if cache_key in self._request_information_cache: + self._provider.logger.debug( + "This is a cached request, not caching request info because it is " + f"not unique:\n method={method},\n params={params}" + ) + return None + # copy the request counter and find the next request id without incrementing # since this is done when / if the request is successfully sent request_id = next(copy(self._provider.request_counter)) @@ -147,11 +159,20 @@ def get_request_information_for_response( else: # retrieve the request info from the cache using the request id cache_key = generate_cache_key(response["id"]) - request_info = ( - # pop the request info from the cache since we don't need to keep it - # this keeps the cache size bounded - self.pop_cached_request_information(cache_key) - ) + if response in self._provider._request_cache._data.values(): + request_info = ( + # don't pop the request info from the cache, since we need to keep + # it to process future responses + # i.e. request information remains in the cache + self._request_information_cache.get_cache_entry(cache_key) + ) + else: + request_info = ( + # pop the request info from the cache since we don't need to keep it + # this keeps the cache size bounded + self.pop_cached_request_information(cache_key) + ) + if ( request_info is not None and request_info.method == "eth_unsubscribe" diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index 949c6b0c68..3189c6998a 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -23,6 +23,7 @@ ) from web3._utils.caching import ( + async_handle_request_caching, generate_cache_key, ) from web3.exceptions import ( @@ -147,6 +148,7 @@ async def disconnect(self) -> None: self._request_processor.clear_caches() + @async_handle_request_caching async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: request_data = self.encode_rpc_request(method, params) diff --git a/web3/tools/benchmark/main.py b/web3/tools/benchmark/main.py index 8c40e12cee..bc96416970 100644 --- a/web3/tools/benchmark/main.py +++ b/web3/tools/benchmark/main.py @@ -24,8 +24,6 @@ Web3, ) from web3.middleware import ( - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, buffered_gas_estimate_middleware, gas_price_strategy_middleware, ) @@ -74,10 +72,7 @@ async def build_async_w3_http(endpoint_uri: str) -> AsyncWeb3: await wait_for_aiohttp(endpoint_uri) _w3 = AsyncWeb3( AsyncHTTPProvider(endpoint_uri), - middlewares=[ - async_gas_price_strategy_middleware, - async_buffered_gas_estimate_middleware, - ], + middlewares=[gas_price_strategy_middleware, buffered_gas_estimate_middleware], ) return _w3 diff --git a/web3/types.py b/web3/types.py index 8dfc1fbdf3..413d0f0af3 100644 --- a/web3/types.py +++ b/web3/types.py @@ -33,9 +33,6 @@ FallbackFn, ReceiveFn, ) -from web3.datastructures import ( - NamedElementOnion, -) if TYPE_CHECKING: from web3.contract.async_contract import AsyncContractFunction # noqa: F401 @@ -222,7 +219,7 @@ class BlockData(TypedDict, total=False): withdrawals: Sequence[WithdrawalData] withdrawalsRoot: HexBytes - # geth_poa_middleware replaces extraData w/ proofOfAuthorityData + # extradata_to_poa_middleware replaces extraData w/ proofOfAuthorityData proofOfAuthorityData: HexBytes @@ -316,15 +313,8 @@ class CreateAccessListResponse(TypedDict): gasUsed: int -Middleware = Callable[[Callable[[RPCEndpoint, Any], RPCResponse], "Web3"], Any] -AsyncMiddlewareCoroutine = Callable[ - [RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse] -] -AsyncMiddleware = Callable[ - [Callable[[RPCEndpoint, Any], RPCResponse], "AsyncWeb3"], Any -] -MiddlewareOnion = NamedElementOnion[str, Middleware] -AsyncMiddlewareOnion = NamedElementOnion[str, AsyncMiddleware] +MakeRequestFn = Callable[[RPCEndpoint, Any], RPCResponse] +AsyncMakeRequestFn = Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]] class FormattersDict(TypedDict, total=False):