From 87f6908d0353bd65efa47b7e603083463e0a6d94 Mon Sep 17 00:00:00 2001 From: fselmo Date: Tue, 7 Nov 2023 18:17:53 -0700 Subject: [PATCH 01/24] Initial commit for middleware refactor - Refactor middlewares as classes with request processors and response processors - Apply this refactor to some existing middlewares and test - gas_price_strategy_middleware - validation / formatting middleware - attrdict_middleware - (incomplete) ens_name_to_address_middleware - Refactor the exception_retry_middleware to be a configuration on the http providers --- .gitignore | 1 + docs/conf.py | 9 +- ens/utils.py | 25 +-- setup.py | 1 + tests/core/eth-module/test_poa.py | 6 +- .../middleware/test_attrdict_middleware.py | 3 +- web3/__init__.py | 4 +- web3/auto/gethdev.py | 4 +- web3/main.py | 4 +- web3/manager.py | 74 +++---- web3/middleware/__init__.py | 76 +------ web3/middleware/abi.py | 10 +- web3/middleware/attrdict.py | 73 +++---- web3/middleware/base.py | 41 ++++ web3/middleware/cache.py | 14 ++ web3/middleware/exception_retry_request.py | 188 ---------------- web3/middleware/formatting.py | 201 +++++++++--------- web3/middleware/gas_price_strategy.py | 50 ++--- web3/middleware/geth_poa.py | 29 +-- web3/middleware/names.py | 50 +++-- .../normalize_request_parameters.py | 4 +- web3/middleware/pythonic.py | 4 +- web3/middleware/validation.py | 17 +- web3/providers/__init__.py | 2 +- web3/providers/async_base.py | 33 +-- web3/providers/base.py | 37 +--- web3/providers/eth_tester/main.py | 10 +- web3/providers/eth_tester/middleware.py | 21 +- web3/providers/rpc/__init__.py | 6 + web3/providers/{ => rpc}/async_rpc.py | 54 ++++- web3/providers/{ => rpc}/rpc.py | 47 +++- web3/providers/rpc/utils.py | 92 ++++++++ 32 files changed, 511 insertions(+), 679 deletions(-) create mode 100644 web3/middleware/base.py create mode 100644 web3/providers/rpc/__init__.py rename web3/providers/{ => rpc}/async_rpc.py (58%) rename web3/providers/{ => rpc}/rpc.py (58%) create mode 100644 web3/providers/rpc/utils.py diff --git a/.gitignore b/.gitignore index ed1b949468..87e55458ea 100644 --- a/.gitignore +++ b/.gitignore @@ -49,6 +49,7 @@ docs/web3.gas_strategies.rst docs/web3.middleware.rst docs/web3.providers.eth_tester.rst docs/web3.providers.rst +docs/web3.providers.rpc.rst docs/web3.providers.websocket.rst docs/web3.rst docs/web3.scripts.release.rst diff --git a/docs/conf.py b/docs/conf.py index 78f0cf94c7..210918e4f6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -86,6 +86,7 @@ "web3.gas_strategies.rst", "web3.middleware.rst", "web3.providers.rst", + "web3.providers.rpc.rst", "web3.providers.websocket.rst", "web3.providers.eth_tester.rst", "web3.scripts.*", @@ -224,7 +225,13 @@ # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ - ("index", "Populus.tex", "Populus Documentation", "The Ethereum Foundation", "manual"), + ( + "index", + "Populus.tex", + "Populus Documentation", + "The Ethereum Foundation", + "manual", + ), ] # The name of an image file (relative to this directory) to place at the top of diff --git a/ens/utils.py b/ens/utils.py index fd5e81fd44..d2da542f0d 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -101,16 +101,15 @@ def init_web3( def customize_web3(w3: "_Web3") -> "_Web3": from web3.middleware import ( - make_stalecheck_middleware, + StaleCheckMiddleware, ) if w3.middleware_onion.get("name_to_address"): w3.middleware_onion.remove("name_to_address") if not w3.middleware_onion.get("stalecheck"): - w3.middleware_onion.add( - make_stalecheck_middleware(ACCEPTABLE_STALE_HOURS * 3600), name="stalecheck" - ) + stalecheck_middleware = StaleCheckMiddleware(ACCEPTABLE_STALE_HOURS * 3600) + w3.middleware_onion.add(stalecheck_middleware, name="stalecheck") return w3 @@ -307,6 +306,9 @@ def init_async_web3( from web3.eth import ( AsyncEth as AsyncEthMain, ) + from web3.middleware import ( + StaleCheckMiddleware, + ) middlewares = list(middlewares) for i, (middleware, name) in enumerate(middlewares): @@ -314,7 +316,9 @@ def init_async_web3( middlewares.pop(i) if "stalecheck" not in (name for mw, name in middlewares): - middlewares.append((_async_ens_stalecheck_middleware, "stalecheck")) + middlewares.append( + (StaleCheckMiddleware(ACCEPTABLE_STALE_HOURS * 3600), "stalecheck") + ) if provider is default: async_w3 = AsyncWeb3Main( @@ -329,14 +333,3 @@ def init_async_web3( ) return async_w3 - - -async def _async_ens_stalecheck_middleware( - make_request: Callable[["RPCEndpoint", Any], Any], w3: "AsyncWeb3" -) -> "Middleware": - from web3.middleware import ( - async_make_stalecheck_middleware, - ) - - middleware = await async_make_stalecheck_middleware(ACCEPTABLE_STALE_HOURS * 3600) - return await middleware(make_request, w3) diff --git a/setup.py b/setup.py index 45278eccf8..264e0ac581 100644 --- a/setup.py +++ b/setup.py @@ -78,6 +78,7 @@ "jsonschema>=4.0.0", "lru-dict>=1.1.6,<1.3.0", "protobuf>=4.21.6", + "pydantic>=2.4.0", "pywin32>=223;platform_system=='Windows'", "requests>=2.16.0", "typing-extensions>=4.0.1", diff --git a/tests/core/eth-module/test_poa.py b/tests/core/eth-module/test_poa.py index 45884066cb..a9481b3245 100644 --- a/tests/core/eth-module/test_poa.py +++ b/tests/core/eth-module/test_poa.py @@ -6,7 +6,7 @@ ) from web3.middleware import ( construct_fixture_middleware, - geth_poa_middleware, + extradata_to_poa, ) @@ -39,7 +39,7 @@ def test_geth_proof_of_authority(w3): "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, } ) - w3.middleware_onion.inject(geth_poa_middleware, layer=0) + w3.middleware_onion.inject(extradata_to_poa, layer=0) w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) block = w3.eth.get_block("latest") assert "extraData" not in block @@ -52,7 +52,7 @@ def test_returns_none_response(w3): "eth_getBlockByNumber": None, } ) - w3.middleware_onion.inject(geth_poa_middleware, layer=0) + w3.middleware_onion.inject(extradata_to_poa, layer=0) w3.middleware_onion.inject(return_none_response, layer=0) with pytest.raises(BlockNotFound): w3.eth.get_block(100000000000) diff --git a/tests/core/middleware/test_attrdict_middleware.py b/tests/core/middleware/test_attrdict_middleware.py index 21d6de8717..0a36654fb3 100644 --- a/tests/core/middleware/test_attrdict_middleware.py +++ b/tests/core/middleware/test_attrdict_middleware.py @@ -9,7 +9,6 @@ AttributeDict, ) from web3.middleware import ( - async_attrdict_middleware, async_construct_result_generator_middleware, attrdict_middleware, construct_result_generator_middleware, @@ -100,7 +99,7 @@ def test_no_attrdict_middleware_does_not_convert_dicts_to_attrdict(): @pytest.mark.asyncio async def test_async_attrdict_middleware_default_for_async_ethereum_tester_provider(): async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) - assert async_w3.middleware_onion.get("attrdict") == async_attrdict_middleware + assert async_w3.middleware_onion.get("attrdict") == attrdict_middleware @pytest.mark.asyncio diff --git a/web3/__init__.py b/web3/__init__.py index aa2ba141f6..4406207076 100644 --- a/web3/__init__.py +++ b/web3/__init__.py @@ -15,9 +15,6 @@ AsyncWeb3, Web3, ) -from web3.providers.async_rpc import ( # noqa: E402 - AsyncHTTPProvider, -) from web3.providers.eth_tester import ( # noqa: E402 EthereumTesterProvider, ) @@ -25,6 +22,7 @@ IPCProvider, ) from web3.providers.rpc import ( # noqa: E402 + AsyncHTTPProvider, HTTPProvider, ) from web3.providers.websocket import ( # noqa: E402 diff --git a/web3/auto/gethdev.py b/web3/auto/gethdev.py index ebba276991..fc28e21999 100644 --- a/web3/auto/gethdev.py +++ b/web3/auto/gethdev.py @@ -3,11 +3,11 @@ Web3, ) from web3.middleware import ( - geth_poa_middleware, + extradata_to_poa, ) from web3.providers.ipc import ( get_dev_ipc_path, ) w3 = Web3(IPCProvider(get_dev_ipc_path())) -w3.middleware_onion.inject(geth_poa_middleware, layer=0) +w3.middleware_onion.inject(extradata_to_poa, layer=0) diff --git a/web3/main.py b/web3/main.py index 6c9ab7645b..61461b319c 100644 --- a/web3/main.py +++ b/web3/main.py @@ -119,13 +119,11 @@ from web3.providers.ipc import ( IPCProvider, ) -from web3.providers.async_rpc import ( - AsyncHTTPProvider, -) from web3.providers.persistent import ( PersistentConnectionProvider, ) from web3.providers.rpc import ( + AsyncHTTPProvider, HTTPProvider, ) from web3.providers.websocket import ( diff --git a/web3/manager.py b/web3/manager.py index c0adfcca99..43d1175864 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -37,16 +37,9 @@ MethodUnavailable, ) from web3.middleware import ( - abi_middleware, - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_name_to_address_middleware, - async_validation_middleware, attrdict_middleware, - buffered_gas_estimate_middleware, gas_price_strategy_middleware, - name_to_address_middleware, + ens_name_to_address_middleware, validation_middleware, ) from web3.module import ( @@ -77,6 +70,9 @@ from web3.providers.websocket.request_processor import ( # noqa: F401 RequestProcessor, ) + from web3.middleware.base import ( # noqa: F401 + Web3Middleware, + ) NULL_RESPONSES = [None, HexBytes("0x"), "0x"] @@ -143,11 +139,7 @@ def __init__( self.provider = provider if middlewares is None: - middlewares = ( - self.async_default_middlewares() - if self.provider.is_async - else self.default_middlewares(cast("Web3", w3)) - ) + middlewares = self.default_middlewares(w3) self.middleware_onion = NamedElementOnion(middlewares) @@ -169,7 +161,7 @@ def provider(self, provider: Union["BaseProvider", "AsyncBaseProvider"]) -> None self._provider = provider @staticmethod - def default_middlewares(w3: "Web3") -> List[Tuple[Middleware, str]]: + def default_middlewares(w3: "Web3") -> List[Tuple["Web3Middleware", str]]: """ List the default middlewares for the request manager. Leaving w3 unspecified will prevent the middleware from resolving names. @@ -177,25 +169,10 @@ def default_middlewares(w3: "Web3") -> List[Tuple[Middleware, str]]: """ return [ (gas_price_strategy_middleware, "gas_price_strategy"), - (name_to_address_middleware(w3), "name_to_address"), + # (ens_name_to_address_middleware, "name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), - (abi_middleware, "abi"), - (buffered_gas_estimate_middleware, "gas_estimate"), - ] - - @staticmethod - def async_default_middlewares() -> List[Tuple[AsyncMiddleware, str]]: - """ - List the default async middlewares for the request manager. - Documentation should remain in sync with these defaults. - """ - return [ - (async_gas_price_strategy_middleware, "gas_price_strategy"), - (async_name_to_address_middleware, "name_to_address"), - (async_attrdict_middleware, "attrdict"), - (async_validation_middleware, "validation"), - (async_buffered_gas_estimate_middleware, "gas_estimate"), + # (async_buffered_gas_estimate_middleware, "gas_estimate"), ] # @@ -204,23 +181,36 @@ def async_default_middlewares() -> List[Tuple[AsyncMiddleware, str]]: def _make_request( self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any ) -> RPCResponse: - provider = cast("BaseProvider", self.provider) - request_func = provider.request_func( - cast("Web3", self.w3), cast(MiddlewareOnion, self.middleware_onion) - ) + """ + 1. Pipe the request params through the middleware stack + 2. Make the request using the provider + 3. Pipe the raw response through the middleware stack + """ self.logger.debug(f"Making request. Method: {method}") - return request_func(method, params) + for middleware, _name in self.middleware_onion.middlewares: + params = middleware.process_request_params(self.w3, method, params) + response = self.provider.make_request(method, params) + for middleware, _name in reversed(self.middleware_onion.middlewares): + response = middleware.process_response(self.w3, method, response) + return response async def _coro_make_request( self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any ) -> RPCResponse: - provider = cast("AsyncBaseProvider", self.provider) - request_func = await provider.request_func( - cast("AsyncWeb3", self.w3), - cast(AsyncMiddlewareOnion, self.middleware_onion), - ) + """ + 1. Pipe the request params through the middleware stack + 2. Make the request using the provider + 3. Pipe the raw response through the middleware stack + """ self.logger.debug(f"Making request. Method: {method}") - return await request_func(method, params) + for middleware, _name in self.middleware_onion.middlewares: + params = await middleware.async_process_request_params( + self.w3, method, params + ) + response = await self.provider.make_request(method, params) + for middleware, _name in reversed(self.middleware_onion.middlewares): + response = await middleware.async_process_response(self.w3, response) + return response # # formatted_response parses and validates JSON-RPC responses for expected diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 39697c2720..5a5fd489c1 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -1,19 +1,3 @@ -import functools -from typing import ( - Coroutine, - TYPE_CHECKING, - Any, - Callable, - Sequence, -) - -from web3.types import ( - AsyncMiddleware, - Middleware, - RPCEndpoint, - RPCResponse, -) - from .abi import ( abi_middleware, ) @@ -22,7 +6,6 @@ async_construct_simple_cache_middleware, ) from .attrdict import ( - async_attrdict_middleware, attrdict_middleware, ) from .buffered_gas_estimate import ( @@ -40,10 +23,6 @@ from .exception_handling import ( construct_exception_handler_middleware, ) -from .exception_retry_request import ( - async_http_retry_request_middleware, - http_retry_request_middleware, -) from .filter import ( async_local_filter_middleware, local_filter_middleware, @@ -55,20 +34,14 @@ construct_fixture_middleware, construct_result_generator_middleware, ) -from .formatting import ( - construct_formatting_middleware, -) from .gas_price_strategy import ( - async_gas_price_strategy_middleware, gas_price_strategy_middleware, ) from .geth_poa import ( - async_geth_poa_middleware, - geth_poa_middleware, + extradata_to_poa_middleware, ) from .names import ( - async_name_to_address_middleware, - name_to_address_middleware, + ens_name_to_address_middleware, ) from .normalize_request_parameters import ( request_parameter_normalizer, @@ -84,50 +57,5 @@ make_stalecheck_middleware, ) from .validation import ( - async_validation_middleware, validation_middleware, ) - -if TYPE_CHECKING: - from web3 import AsyncWeb3, Web3 - - -def combine_middlewares( - middlewares: Sequence[Middleware], - w3: "Web3", - provider_request_fn: Callable[[RPCEndpoint, Any], Any], -) -> Callable[..., RPCResponse]: - """ - Returns a callable function which will call the provider.provider_request - function wrapped with all of the middlewares. - """ - return functools.reduce( - lambda request_fn, middleware: middleware(request_fn, w3), - reversed(middlewares), - provider_request_fn, - ) - - -async def async_combine_middlewares( - middlewares: Sequence[AsyncMiddleware], - async_w3: "AsyncWeb3", - provider_request_fn: Callable[[RPCEndpoint, Any], Any], -) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - """ - Returns a callable function which will call the provider.provider_request - function wrapped with all of the middlewares. - """ - accumulator_fn = provider_request_fn - for middleware in reversed(middlewares): - accumulator_fn = await construct_middleware( - middleware, accumulator_fn, async_w3 - ) - return accumulator_fn - - -async def construct_middleware( - async_middleware: AsyncMiddleware, - fn: Callable[..., RPCResponse], - async_w3: "AsyncWeb3", -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - return await async_middleware(fn, async_w3) diff --git a/web3/middleware/abi.py b/web3/middleware/abi.py index 97e2063df1..65290ce87b 100644 --- a/web3/middleware/abi.py +++ b/web3/middleware/abi.py @@ -1,11 +1,9 @@ from web3._utils.method_formatters import ( ABI_REQUEST_FORMATTERS, ) - -from .formatting import ( - construct_formatting_middleware, +from web3.middleware.formatting import ( + FormattingMiddleware, ) -abi_middleware = construct_formatting_middleware( - request_formatters=ABI_REQUEST_FORMATTERS -) + +abi_middleware = FormattingMiddleware(request_formatters=ABI_REQUEST_FORMATTERS) diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index ea0c152d5d..e433071243 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -1,3 +1,4 @@ +from abc import ABC from typing import ( TYPE_CHECKING, Any, @@ -13,8 +14,10 @@ from web3.datastructures import ( AttributeDict, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddlewareCoroutine, RPCEndpoint, RPCResponse, ) @@ -29,9 +32,25 @@ ) -def attrdict_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: +def _handle_async_response(response: RPCResponse) -> RPCResponse: + if "result" in response: + return assoc(response, "result", AttributeDict.recursive(response["result"])) + elif "params" in response and "result" in response["params"]: + # this is a subscription response + return assoc( + response, + "params", + assoc( + response["params"], + "result", + AttributeDict.recursive(response["params"]["result"]), + ), + ) + else: + return response + + +class AttributeDictMiddleware(Web3Middleware, ABC): """ Converts any result which is a dictionary into an `AttributeDict`. @@ -39,9 +58,9 @@ def attrdict_middleware( (e.g. my_attribute_dict.property1) will not preserve typing. """ - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - response = make_request(method, params) - + def process_response( + self, w3: "Web3", method: "RPCEndpoint", response: "RPCResponse" + ) -> Any: if "result" in response: return assoc( response, "result", AttributeDict.recursive(response["result"]) @@ -49,24 +68,11 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: else: return response - return middleware - + # -- async -- # -# --- async --- # - - -async def async_attrdict_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - """ - Converts any result which is a dictionary into an `AttributeDict`. - - Note: Accessing `AttributeDict` properties via attribute - (e.g. my_attribute_dict.property1) will not preserve typing. - """ - - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - response = await make_request(method, params) + async def async_process_response( + self, async_w3: "AsyncWeb3", method: "RPCEndpoint", response: "RPCResponse" + ) -> Any: if async_w3.provider.has_persistent_connection: # asynchronous response processing provider = cast("PersistentConnectionProvider", async_w3.provider) @@ -77,22 +83,5 @@ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: else: return _handle_async_response(response) - return middleware - -def _handle_async_response(response: RPCResponse) -> RPCResponse: - if "result" in response: - return assoc(response, "result", AttributeDict.recursive(response["result"])) - elif "params" in response and "result" in response["params"]: - # this is a subscription response - return assoc( - response, - "params", - assoc( - response["params"], - "result", - AttributeDict.recursive(response["params"]["result"]), - ), - ) - else: - return response +attrdict_middleware = AttributeDictMiddleware() diff --git a/web3/middleware/base.py b/web3/middleware/base.py new file mode 100644 index 0000000000..49b83a640b --- /dev/null +++ b/web3/middleware/base.py @@ -0,0 +1,41 @@ +from abc import abstractmethod +from typing import Any, TYPE_CHECKING + + +if TYPE_CHECKING: + from web3 import ( + AsyncWeb3, + Web3, + ) + from web3.types import ( + RPCEndpoint, + RPCResponse, + ) + + +class Web3Middleware: + @abstractmethod + def process_request_params( + self, w3: "Web3", method: "RPCEndpoint", params: Any + ) -> Any: + return params + + @abstractmethod + def process_response( + self, w3: "Web3", method: "RPCEndpoint", response: "RPCResponse" + ) -> Any: + return response + + # -- async -- # + + @abstractmethod + async def async_process_request_params( + self, async_w3: "AsyncWeb3", method: "RPCEndpoint", params: Any + ) -> Any: + return params + + @abstractmethod + async def async_process_response( + self, async_w3: "AsyncWeb3", method: "RPCEndpoint", response: "RPCResponse" + ) -> Any: + return response diff --git a/web3/middleware/cache.py b/web3/middleware/cache.py index 189cf706cb..39e871f6a2 100644 --- a/web3/middleware/cache.py +++ b/web3/middleware/cache.py @@ -23,6 +23,7 @@ Literal, TypedDict, ) +from web3.middleware.base import Web3Middleware from web3.types import ( BlockData, BlockNumber, @@ -122,6 +123,19 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: _simple_cache_middleware = construct_simple_cache_middleware() +class SimpleCacheMiddleware(Web3Middleware): + """ + Constructs a middleware which caches responses based on the request + ``method`` and ``params`` + + :param cache: A ``SimpleCache`` class. + :param rpc_whitelist: A set of RPC methods which may have their responses cached. + :param should_cache_fn: A callable which accepts ``method`` ``params`` and + ``response`` and returns a boolean as to whether the response should be + cached. + """ + + TIME_BASED_CACHE_RPC_WHITELIST = cast( Set[RPCEndpoint], { diff --git a/web3/middleware/exception_retry_request.py b/web3/middleware/exception_retry_request.py index aa216c4e28..e69de29bb2 100644 --- a/web3/middleware/exception_retry_request.py +++ b/web3/middleware/exception_retry_request.py @@ -1,188 +0,0 @@ -import asyncio -import time -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, - List, - Optional, - Type, -) - -import aiohttp -from requests.exceptions import ( - ConnectionError, - HTTPError, - Timeout, - TooManyRedirects, -) - -from web3.types import ( - AsyncMiddlewareCoroutine, - RPCEndpoint, - RPCResponse, -) - -if TYPE_CHECKING: - from web3 import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - -DEFAULT_ALLOWLIST = [ - "admin", - "miner", - "net", - "txpool", - "testing", - "evm", - "eth_protocolVersion", - "eth_syncing", - "eth_coinbase", - "eth_mining", - "eth_hashrate", - "eth_chainId", - "eth_gasPrice", - "eth_accounts", - "eth_blockNumber", - "eth_getBalance", - "eth_getStorageAt", - "eth_getProof", - "eth_getCode", - "eth_getBlockByNumber", - "eth_getBlockByHash", - "eth_getBlockTransactionCountByNumber", - "eth_getBlockTransactionCountByHash", - "eth_getUncleCountByBlockNumber", - "eth_getUncleCountByBlockHash", - "eth_getTransactionByHash", - "eth_getTransactionByBlockHashAndIndex", - "eth_getTransactionByBlockNumberAndIndex", - "eth_getTransactionReceipt", - "eth_getTransactionCount", - "eth_getRawTransactionByHash", - "eth_call", - "eth_createAccessList", - "eth_estimateGas", - "eth_maxPriorityFeePerGas", - "eth_newBlockFilter", - "eth_newPendingTransactionFilter", - "eth_newFilter", - "eth_getFilterChanges", - "eth_getFilterLogs", - "eth_getLogs", - "eth_uninstallFilter", - "eth_getCompilers", - "eth_getWork", - "eth_sign", - "eth_signTypedData", - "eth_sendRawTransaction", - "personal_importRawKey", - "personal_newAccount", - "personal_listAccounts", - "personal_listWallets", - "personal_lockAccount", - "personal_unlockAccount", - "personal_ecRecover", - "personal_sign", - "personal_signTypedData", -] - - -def check_if_retry_on_failure( - method: str, allow_list: Optional[List[str]] = None -) -> bool: - if allow_list is None: - allow_list = DEFAULT_ALLOWLIST - - root = method.split("_")[0] - if root in allow_list: - return True - elif method in allow_list: - return True - else: - return False - - -def exception_retry_middleware( - make_request: Callable[[RPCEndpoint, Any], RPCResponse], - _w3: "Web3", - errors: Collection[Type[BaseException]], - retries: int = 5, - backoff_factor: float = 0.3, - allow_list: Optional[List[str]] = None, -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - """ - Creates middleware that retries failed HTTP requests. Is a default - middleware for HTTPProvider. - """ - - def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if check_if_retry_on_failure(method, allow_list): - for i in range(retries): - try: - return make_request(method, params) - except tuple(errors): - if i < retries - 1: - time.sleep(backoff_factor) - continue - else: - raise - return None - else: - return make_request(method, params) - - return middleware - - -def http_retry_request_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], Any]: - return exception_retry_middleware( - make_request, w3, (ConnectionError, HTTPError, Timeout, TooManyRedirects) - ) - - -# -- async -- # - - -async def async_exception_retry_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - _async_w3: "AsyncWeb3", - errors: Collection[Type[BaseException]], - retries: int = 5, - backoff_factor: float = 0.3, - allow_list: Optional[List[str]] = None, -) -> AsyncMiddlewareCoroutine: - """ - Creates middleware that retries failed HTTP requests. - Is a default middleware for AsyncHTTPProvider. - """ - - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if check_if_retry_on_failure(method, allow_list): - for i in range(retries): - try: - return await make_request(method, params) - except tuple(errors): - if i < retries - 1: - await asyncio.sleep(backoff_factor) - continue - else: - raise - return None - else: - return await make_request(method, params) - - return middleware - - -async def async_http_retry_request_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> Callable[[RPCEndpoint, Any], Any]: - return await async_exception_retry_middleware( - make_request, - async_w3, - (TimeoutError, aiohttp.ClientError), - ) diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 1fddb9cea4..356dd398c3 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -4,6 +4,7 @@ Callable, Coroutine, Optional, + TypeVar, cast, ) @@ -13,14 +14,14 @@ merge, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, EthSubscriptionParams, Formatters, FormattersDict, Literal, - Middleware, RPCEndpoint, RPCResponse, ) @@ -85,125 +86,113 @@ def _format_response( return response -# --- sync -- # +SYNC_FORMATTERS_BUILDER = Callable[["Web3", RPCEndpoint], FormattersDict] +ASYNC_FORMATTERS_BUILDER = Callable[ + ["AsyncWeb3", RPCEndpoint], Coroutine[Any, Any, FormattersDict] +] -def construct_formatting_middleware( - request_formatters: Optional[Formatters] = None, - result_formatters: Optional[Formatters] = None, - error_formatters: Optional[Formatters] = None, -) -> Middleware: - def ignore_web3_in_standard_formatters( - _w3: "Web3", - _method: RPCEndpoint, - ) -> FormattersDict: - return dict( - request_formatters=request_formatters or {}, - result_formatters=result_formatters or {}, - error_formatters=error_formatters or {}, - ) +class FormattingMiddleware(Web3Middleware): + def __init__( + self, + # formatters option: + request_formatters: Optional[Formatters] = None, + result_formatters: Optional[Formatters] = None, + error_formatters: Optional[Formatters] = None, + # formatters builder option: + sync_formatters_builder: Optional[SYNC_FORMATTERS_BUILDER] = None, + async_formatters_builder: Optional[ASYNC_FORMATTERS_BUILDER] = None, + ): + # if not both sync and async formatters are specified, raise error + if ( + sync_formatters_builder is None and async_formatters_builder is not None + ) or (sync_formatters_builder is not None and async_formatters_builder is None): + raise ValueError( + "Must specify both sync_formatters_builder and async_formatters_builder" + ) - return construct_web3_formatting_middleware(ignore_web3_in_standard_formatters) + if sync_formatters_builder is not None and async_formatters_builder is not None: + if ( + request_formatters is not None + or result_formatters is not None + or error_formatters is not None + ): + raise ValueError( + "Cannot specify formatters_builder and formatters at the same time" + ) + self.request_formatters = request_formatters or {} + self.result_formatters = result_formatters or {} + self.error_formatters = error_formatters or {} + self.sync_formatters_builder = sync_formatters_builder + self.async_formatters_builder = async_formatters_builder -def construct_web3_formatting_middleware( - web3_formatters_builder: Callable[["Web3", RPCEndpoint], FormattersDict], -) -> Middleware: - def formatter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - w3: "Web3", - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + def process_request_params( + self, w3: "Web3", method: "RPCEndpoint", params: Any + ) -> Any: + if self.sync_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - web3_formatters_builder(w3, method), - ) - request_formatters = formatters.pop("request_formatters") - - if method in request_formatters: - formatter = request_formatters[method] - params = formatter(params) - response = make_request(method, params) - - return _apply_response_formatters( - method, - formatters["result_formatters"], - formatters["error_formatters"], - response, + self.sync_formatters_builder(w3, method), ) + self.request_formatters = formatters.pop("request_formatters") - return middleware - - return formatter_middleware - - -# --- async --- # + if method in self.request_formatters: + formatter = self.request_formatters[method] + params = formatter(params) + return params -async def async_construct_formatting_middleware( - request_formatters: Optional[Formatters] = None, - result_formatters: Optional[Formatters] = None, - error_formatters: Optional[Formatters] = None, -) -> AsyncMiddleware: - async def ignore_web3_in_standard_formatters( - _async_w3: "AsyncWeb3", - _method: RPCEndpoint, - ) -> FormattersDict: - return dict( - request_formatters=request_formatters or {}, - result_formatters=result_formatters or {}, - error_formatters=error_formatters or {}, + def process_response( + self, w3: "Web3", method: RPCEndpoint, response: "RPCResponse" + ) -> Any: + if self.sync_formatters_builder is not None: + formatters = merge( + FORMATTER_DEFAULTS, + self.sync_formatters_builder(w3, method), + ) + self.result_formatters = formatters["result_formatters"] + self.error_formatters = formatters["error_formatters"] + + return _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + response, ) - return await async_construct_web3_formatting_middleware( - ignore_web3_in_standard_formatters - ) - + # -- async -- # -async def async_construct_web3_formatting_middleware( - async_web3_formatters_builder: Callable[ - ["AsyncWeb3", RPCEndpoint], Coroutine[Any, Any, FormattersDict] - ] -) -> Callable[ - [Callable[[RPCEndpoint, Any], Any], "AsyncWeb3"], - Coroutine[Any, Any, AsyncMiddlewareCoroutine], -]: - async def formatter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - async_w3: "AsyncWeb3", - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: + async def async_process_request_params( + self, async_w3: "AsyncWeb3", method: "RPCEndpoint", params: Any + ) -> Any: + if self.async_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - await async_web3_formatters_builder(async_w3, method), + await self.async_formatters_builder(async_w3, method), ) - request_formatters = formatters.pop("request_formatters") - - if method in request_formatters: - formatter = request_formatters[method] - params = formatter(params) - response = await make_request(method, params) - - if async_w3.provider.has_persistent_connection: - # asynchronous response processing - provider = cast("PersistentConnectionProvider", async_w3.provider) - provider._request_processor.append_middleware_response_processor( - response, - _apply_response_formatters( - method, - formatters["result_formatters"], - formatters["error_formatters"], - ), - ) - return response - else: - return _apply_response_formatters( - method, - formatters["result_formatters"], - formatters["error_formatters"], - response, - ) + self.request_formatters = formatters.pop("request_formatters") + + if method in self.request_formatters: + formatter = self.request_formatters[method] + params = formatter(params) - return middleware + return params - return formatter_middleware + async def async_process_response( + self, async_w3: "AsyncWeb3", method: RPCEndpoint, response: "RPCResponse" + ) -> Any: + if self.async_formatters_builder is not None: + formatters = merge( + FORMATTER_DEFAULTS, + await self.async_formatters_builder(async_w3, method), + ) + self.result_formatters = formatters["result_formatters"] + self.error_formatters = formatters["error_formatters"] + + return _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + response, + ) diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index 1cd5a31f09..ac6e4a8090 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -1,7 +1,7 @@ +from abc import ABC from typing import ( TYPE_CHECKING, Any, - Callable, ) from eth_utils.toolz import ( @@ -23,11 +23,12 @@ InvalidTransaction, TransactionTypeMismatch, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddlewareCoroutine, BlockData, RPCEndpoint, - RPCResponse, TxParams, Wei, ) @@ -78,18 +79,18 @@ def validate_transaction_params( return transaction -def gas_price_strategy_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: +class GasPriceStrategyMiddleware(Web3Middleware, ABC): """ - - Uses a gas price strategy if one is set. This is only supported - for legacy transactions. It is recommended to send dynamic fee - transactions (EIP-1559) whenever possible. + - Uses a gas price strategy if one is set. This is only supported for + legacy transactions. It is recommended to send dynamic fee transactions + (EIP-1559) whenever possible. - Validates transaction params against legacy and dynamic fee txn values. """ - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + def process_request_params( + self, w3: "Web3", method: RPCEndpoint, params: Any + ) -> Any: if method == "eth_sendTransaction": transaction = params[0] generated_gas_price = w3.eth.generate_gas_price(transaction) @@ -97,24 +98,14 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) - return make_request(method, (transaction,)) - return make_request(method, params) - - return middleware - - -async def async_gas_price_strategy_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - """ - - Uses a gas price strategy if one is set. This is only supported for - legacy transactions. It is recommended to send dynamic fee transactions - (EIP-1559) whenever possible. + return (transaction,) + return params - - Validates transaction params against legacy and dynamic fee txn values. - """ + # -- async -- # - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + async def async_process_request_params( + self, async_w3: "AsyncWeb3", method: RPCEndpoint, params: Any + ) -> Any: if method == "eth_sendTransaction": transaction = params[0] generated_gas_price = async_w3.eth.generate_gas_price(transaction) @@ -122,7 +113,8 @@ async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) - return await make_request(method, (transaction,)) - return await make_request(method, params) + return (transaction,) + return params + - return middleware +gas_price_strategy_middleware = GasPriceStrategyMiddleware() diff --git a/web3/middleware/geth_poa.py b/web3/middleware/geth_poa.py index dd80431991..e27b493394 100644 --- a/web3/middleware/geth_poa.py +++ b/web3/middleware/geth_poa.py @@ -24,10 +24,7 @@ from web3._utils.rpc_abi import ( RPC, ) -from web3.middleware.formatting import ( - async_construct_formatting_middleware, - construct_formatting_middleware, -) +from web3.middleware.formatting import FormattingMiddleware from web3.types import ( AsyncMiddlewareCoroutine, RPCEndpoint, @@ -56,26 +53,14 @@ geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) -geth_poa_middleware = construct_formatting_middleware( +extradata_to_poa_middleware = FormattingMiddleware( result_formatters={ RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), + RPC.eth_subscribe: apply_formatter_if( + is_not_null, + # original call to eth_subscribe returns a string, needs a dict check + apply_formatter_if(is_dict, geth_poa_cleanup), + ), }, ) - - -async def async_geth_poa_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - middleware = await async_construct_formatting_middleware( - result_formatters={ - RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), - RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), - RPC.eth_subscribe: apply_formatter_if( - is_not_null, - # original call to eth_subscribe returns a string, needs a dict check - apply_formatter_if(is_dict, geth_poa_cleanup), - ), - }, - ) - return await middleware(make_request, w3) diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 71a8cb4421..54ca201e15 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -1,7 +1,7 @@ +from abc import ABC from typing import ( TYPE_CHECKING, Any, - Callable, Dict, Sequence, Union, @@ -20,10 +20,9 @@ abi_request_formatters, ) from web3.types import ( - AsyncMiddlewareCoroutine, - Middleware, RPCEndpoint, ) +from .base import Web3Middleware from .._utils.abi import ( abi_data_tree, @@ -34,7 +33,7 @@ recursive_map, ) from .formatting import ( - construct_formatting_middleware, + FormattingMiddleware, ) if TYPE_CHECKING: @@ -44,18 +43,6 @@ ) -def name_to_address_middleware(w3: "Web3") -> Middleware: - normalizers = [ - abi_ens_resolver(w3), - ] - return construct_formatting_middleware( - request_formatters=abi_request_formatters(normalizers, RPC_ABIS) - ) - - -# -- async -- # - - def _is_logs_subscription_with_optional_args(method: RPCEndpoint, params: Any) -> bool: return method == "eth_subscribe" and len(params) == 2 and params[0] == "logs" @@ -109,11 +96,26 @@ async def async_apply_ens_to_address_conversion( ) -async def async_name_to_address_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], - async_w3: "AsyncWeb3", -) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Any: +class EnsNameToAddressMiddleware(Web3Middleware, ABC): + def __init__(self) -> None: + # the sync version of this middleware utilizes the FormattingMiddleware + self._formatting_middleware = FormattingMiddleware( + request_formatters=abi_request_formatters( # type: ignore + [abi_ens_resolver], + RPC_ABIS, # type: ignore + ) + ) + + def process_request_params( + self, w3: "Web3", method: "RPCEndpoint", params: Any + ) -> Any: + return self._formatting_middleware.process_request_params(w3, method, params) + + # -- async -- # + + async def async_process_request_params( + self, async_w3: "AsyncWeb3", method: "RPCEndpoint", params: Any + ) -> Any: abi_types_for_method = RPC_ABIS.get(method, None) if abi_types_for_method is not None: @@ -136,6 +138,8 @@ async def middleware(method: RPCEndpoint, params: Any) -> Any: params, abi_types_for_method, ) - return await make_request(method, params) - return middleware + return params + + +ens_name_to_address_middleware = EnsNameToAddressMiddleware() diff --git a/web3/middleware/normalize_request_parameters.py b/web3/middleware/normalize_request_parameters.py index 93926c99a4..66c8618df1 100644 --- a/web3/middleware/normalize_request_parameters.py +++ b/web3/middleware/normalize_request_parameters.py @@ -3,9 +3,9 @@ ) from .formatting import ( - construct_formatting_middleware, + FormattingMiddleware, ) -request_parameter_normalizer = construct_formatting_middleware( +request_parameter_normalizer = FormattingMiddleware( request_formatters=METHOD_NORMALIZERS, ) diff --git a/web3/middleware/pythonic.py b/web3/middleware/pythonic.py index 1ffd253e68..f83ab5568a 100644 --- a/web3/middleware/pythonic.py +++ b/web3/middleware/pythonic.py @@ -3,10 +3,10 @@ PYTHONIC_RESULT_FORMATTERS, ) from web3.middleware.formatting import ( - construct_formatting_middleware, + FormattingMiddleware, ) -pythonic_middleware = construct_formatting_middleware( +pythonic_middleware = FormattingMiddleware( request_formatters=PYTHONIC_REQUEST_FORMATTERS, result_formatters=PYTHONIC_RESULT_FORMATTERS, ) diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 9686a86715..99b3660a59 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -33,8 +33,7 @@ Web3ValidationError, ) from web3.middleware.formatting import ( - async_construct_web3_formatting_middleware, - construct_web3_formatting_middleware, + FormattingMiddleware, ) from web3.types import ( AsyncMiddlewareCoroutine, @@ -147,9 +146,6 @@ def build_method_validators(w3: "Web3", method: RPCEndpoint) -> FormattersDict: return _build_formatters_dict(request_formatters) -validation_middleware = construct_web3_formatting_middleware(build_method_validators) - - # -- async --- # @@ -165,10 +161,7 @@ async def async_build_method_validators( return _build_formatters_dict(request_formatters) -async def async_validation_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - middleware = await async_construct_web3_formatting_middleware( - async_build_method_validators - ) - return await middleware(make_request, w3) +validation_middleware = FormattingMiddleware( + sync_formatters_builder=build_method_validators, + async_formatters_builder=async_build_method_validators, +) diff --git a/web3/providers/__init__.py b/web3/providers/__init__.py index 883b88060e..e3a74a4d93 100644 --- a/web3/providers/__init__.py +++ b/web3/providers/__init__.py @@ -1,7 +1,7 @@ from .async_base import ( AsyncBaseProvider, ) -from .async_rpc import ( +from .rpc import ( AsyncHTTPProvider, ) from .base import ( diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 30404b6b7c..1a0d1ac73c 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -4,7 +4,6 @@ Any, Callable, Coroutine, - Sequence, Tuple, cast, ) @@ -22,12 +21,9 @@ from web3.exceptions import ( ProviderConnectionError, ) -from web3.middleware import ( - async_combine_middlewares, -) +from web3.middleware.base import Web3Middleware from web3.types import ( AsyncMiddleware, - AsyncMiddlewareOnion, MiddlewareOnion, RPCEndpoint, RPCResponse, @@ -41,7 +37,7 @@ class AsyncBaseProvider: - _middlewares: Tuple[AsyncMiddleware, ...] = () + _middlewares: Tuple[Web3Middleware, ...] = () # a tuple of (all_middlewares, request_func) _request_func_cache: Tuple[ Tuple[AsyncMiddleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] @@ -56,7 +52,7 @@ class AsyncBaseProvider: ccip_read_max_redirects: int = 4 @property - def middlewares(self) -> Tuple[AsyncMiddleware, ...]: + def middlewares(self) -> Tuple[Web3Middleware, ...]: return self._middlewares @middlewares.setter @@ -64,29 +60,6 @@ def middlewares(self, values: MiddlewareOnion) -> None: # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] self._middlewares = tuple(values) # type: ignore - async def request_func( - self, async_w3: "AsyncWeb3", outer_middlewares: AsyncMiddlewareOnion - ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - all_middlewares: Tuple[AsyncMiddleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 - - cache_key = self._request_func_cache[0] - if cache_key is None or cache_key != all_middlewares: - self._request_func_cache = ( - all_middlewares, - await self._generate_request_func(async_w3, all_middlewares), - ) - return self._request_func_cache[-1] - - async def _generate_request_func( - self, async_w3: "AsyncWeb3", middlewares: Sequence[AsyncMiddleware] - ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - return await async_combine_middlewares( - middlewares=middlewares, - async_w3=async_w3, - provider_request_fn=self.make_request, - ) - async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") diff --git a/web3/providers/base.py b/web3/providers/base.py index d7877546cd..b88740db3c 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -20,9 +20,7 @@ from web3.exceptions import ( ProviderConnectionError, ) -from web3.middleware import ( - combine_middlewares, -) +from web3.middleware.base import Web3Middleware from web3.types import ( Middleware, MiddlewareOnion, @@ -35,7 +33,7 @@ class BaseProvider: - _middlewares: Tuple[Middleware, ...] = () + _middlewares: Tuple[Web3Middleware, ...] = () # a tuple of (all_middlewares, request_func) _request_func_cache: Tuple[Tuple[Middleware, ...], Callable[..., RPCResponse]] = ( None, @@ -48,7 +46,7 @@ class BaseProvider: ccip_read_max_redirects: int = 4 @property - def middlewares(self) -> Tuple[Middleware, ...]: + def middlewares(self) -> Tuple[Web3Middleware, ...]: return self._middlewares @middlewares.setter @@ -56,35 +54,6 @@ def middlewares(self, values: MiddlewareOnion) -> None: # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] self._middlewares = tuple(values) # type: ignore - def request_func( - self, w3: "Web3", outer_middlewares: MiddlewareOnion - ) -> Callable[..., RPCResponse]: - """ - @param outer_middlewares is an iterable of middlewares, - ordered by first to execute - @returns a function that calls all the middleware and - eventually self.make_request() - """ - # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - all_middlewares: Tuple[Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 - - cache_key = self._request_func_cache[0] - if cache_key is None or cache_key != all_middlewares: - self._request_func_cache = ( - all_middlewares, - self._generate_request_func(w3, all_middlewares), - ) - return self._request_func_cache[-1] - - def _generate_request_func( - self, w3: "Web3", middlewares: Sequence[Middleware] - ) -> Callable[..., RPCResponse]: - return combine_middlewares( - middlewares=middlewares, - w3=w3, - provider_request_fn=self.make_request, - ) - def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index 069aee3603..8952dbb746 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -21,7 +21,6 @@ Literal, ) from web3.middleware.attrdict import ( - async_attrdict_middleware, attrdict_middleware, ) from web3.middleware.buffered_gas_estimate import ( @@ -40,8 +39,6 @@ ) from .middleware import ( - async_default_transaction_fields_middleware, - async_ethereum_tester_middleware, default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -53,10 +50,9 @@ class AsyncEthereumTesterProvider(AsyncBaseProvider): middlewares = ( - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_default_transaction_fields_middleware, - async_ethereum_tester_middleware, + attrdict_middleware, + # async_default_transaction_fields_middleware, + # async_ethereum_tester_middleware, ) def __init__(self) -> None: diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index a50e949b2a..f51a122cca 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -40,11 +40,8 @@ from web3._utils.method_formatters import ( apply_list_to_array_formatter, ) -from web3.middleware import ( - construct_formatting_middleware, -) from web3.middleware.formatting import ( - async_construct_formatting_middleware, + FormattingMiddleware, ) from web3.types import ( AsyncMiddlewareCoroutine, @@ -316,7 +313,7 @@ def is_hexstr(value: Any) -> bool: } -ethereum_tester_middleware = construct_formatting_middleware( +ethereum_tester_middleware = FormattingMiddleware( request_formatters=request_formatters, result_formatters=result_formatters ) @@ -367,13 +364,13 @@ def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: # --- async --- # -async def async_ethereum_tester_middleware( # type: ignore - make_request, web3: "AsyncWeb3" -) -> Middleware: - middleware = await async_construct_formatting_middleware( - request_formatters=request_formatters, result_formatters=result_formatters - ) - return await middleware(make_request, web3) +# async def async_ethereum_tester_middleware( # type: ignore +# make_request, web3: "AsyncWeb3" +# ) -> Middleware: +# middleware = await async_construct_formatting_middleware( +# request_formatters=request_formatters, result_formatters=result_formatters +# ) +# return await middleware(make_request, web3) async def async_guess_from( diff --git a/web3/providers/rpc/__init__.py b/web3/providers/rpc/__init__.py new file mode 100644 index 0000000000..cbac394616 --- /dev/null +++ b/web3/providers/rpc/__init__.py @@ -0,0 +1,6 @@ +from .async_rpc import ( + AsyncHTTPProvider, +) +from .rpc import ( + HTTPProvider, +) diff --git a/web3/providers/async_rpc.py b/web3/providers/rpc/async_rpc.py similarity index 58% rename from web3/providers/async_rpc.py rename to web3/providers/rpc/async_rpc.py index e5ab8cb79d..95915ad1fa 100644 --- a/web3/providers/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import ( Any, @@ -32,15 +33,16 @@ RPCResponse, ) -from ..datastructures import ( +from ...datastructures import ( NamedElementOnion, ) -from ..middleware.exception_retry_request import ( - async_http_retry_request_middleware, -) -from .async_base import ( +from ..async_base import ( AsyncJSONBaseProvider, ) +from .utils import ( + check_if_retry_on_failure, + ExceptionRetryConfiguration, +) class AsyncHTTPProvider(AsyncJSONBaseProvider): @@ -48,12 +50,13 @@ class AsyncHTTPProvider(AsyncJSONBaseProvider): endpoint_uri = None _request_kwargs = None # type ignored b/c conflict with _middlewares attr on AsyncBaseProvider - _middlewares: Tuple[AsyncMiddleware, ...] = NamedElementOnion([(async_http_retry_request_middleware, "http_retry_request")]) # type: ignore # noqa: E501 + _middlewares: Tuple[AsyncMiddleware, ...] = NamedElementOnion([]) # type: ignore def __init__( self, endpoint_uri: Optional[Union[URI, str]] = None, request_kwargs: Optional[Any] = None, + exception_retry_configuration: Optional[ExceptionRetryConfiguration] = None, ) -> None: if endpoint_uri is None: self.endpoint_uri = get_default_http_endpoint() @@ -62,6 +65,12 @@ def __init__( self._request_kwargs = request_kwargs or {} + self.exception_retry_configuration = ( + exception_retry_configuration + # use default values if not provided + or ExceptionRetryConfiguration() + ) + super().__init__() async def cache_async_session(self, session: ClientSession) -> ClientSession: @@ -83,14 +92,41 @@ def get_request_headers(self) -> Dict[str, str]: "User-Agent": construct_user_agent(str(type(self))), } + async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: + """ + If exception_retry_configuration is set, retry on failure; otherwise, make + the request without retrying. + """ + if ( + self.exception_retry_configuration is not None + and check_if_retry_on_failure( + method, self.exception_retry_configuration.method_allowlist + ) + ): + for i in range(self.exception_retry_configuration.retries): + try: + return await async_make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + except tuple(self.exception_retry_configuration.errors): + if i < self.exception_retry_configuration.retries - 1: + await asyncio.sleep( + self.exception_retry_configuration.backoff_factor + ) + continue + else: + raise + else: + return await async_make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug( f"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}" ) request_data = self.encode_rpc_request(method, params) - raw_response = await async_make_post_request( - self.endpoint_uri, request_data, **self.get_request_kwargs() - ) + raw_response = await self._make_request(method, request_data) response = self.decode_rpc_response(raw_response) self.logger.debug( f"Getting response HTTP. URI: {self.endpoint_uri}, " diff --git a/web3/providers/rpc.py b/web3/providers/rpc/rpc.py similarity index 58% rename from web3/providers/rpc.py rename to web3/providers/rpc/rpc.py index bffa69392a..70f26af700 100644 --- a/web3/providers/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -1,4 +1,5 @@ import logging +import time from typing import ( Any, Dict, @@ -26,18 +27,19 @@ from web3.datastructures import ( NamedElementOnion, ) -from web3.middleware import ( - http_retry_request_middleware, -) from web3.types import ( Middleware, RPCEndpoint, RPCResponse, ) -from .base import ( +from ..base import ( JSONBaseProvider, ) +from .utils import ( + check_if_retry_on_failure, + ExceptionRetryConfiguration, +) class HTTPProvider(JSONBaseProvider): @@ -46,13 +48,16 @@ class HTTPProvider(JSONBaseProvider): _request_args = None _request_kwargs = None # type ignored b/c conflict with _middlewares attr on BaseProvider - _middlewares: Tuple[Middleware, ...] = NamedElementOnion([(http_retry_request_middleware, "http_retry_request")]) # type: ignore # noqa: E501 + _middlewares: Tuple[Middleware, ...] = NamedElementOnion([]) # type: ignore # noqa: E501 def __init__( self, endpoint_uri: Optional[Union[URI, str]] = None, request_kwargs: Optional[Any] = None, session: Optional[Any] = None, + exception_retry_configuration: Optional[ + ExceptionRetryConfiguration + ] = ExceptionRetryConfiguration(), ) -> None: if endpoint_uri is None: self.endpoint_uri = get_default_http_endpoint() @@ -60,6 +65,7 @@ def __init__( self.endpoint_uri = URI(endpoint_uri) self._request_kwargs = request_kwargs or {} + self.exception_retry_configuration = exception_retry_configuration if session: cache_and_return_session(self.endpoint_uri, session) @@ -82,14 +88,39 @@ def get_request_headers(self) -> Dict[str, str]: "User-Agent": construct_user_agent(str(type(self))), } + def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: + """ + If exception_retry_configuration is set, retry on failure; otherwise, make + the request without retrying. + """ + if ( + self.exception_retry_configuration is not None + and check_if_retry_on_failure( + method, self.exception_retry_configuration.method_allowlist + ) + ): + for i in range(self.exception_retry_configuration.retries): + try: + return make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + except tuple(self.exception_retry_configuration.errors): + if i < self.exception_retry_configuration.retries - 1: + time.sleep(self.exception_retry_configuration.backoff_factor) + continue + else: + raise + else: + return make_post_request( + self.endpoint_uri, request_data, **self.get_request_kwargs() + ) + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug( f"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}" ) request_data = self.encode_rpc_request(method, params) - raw_response = make_post_request( - self.endpoint_uri, request_data, **self.get_request_kwargs() - ) + raw_response = self._make_request(method, request_data) response = self.decode_rpc_response(raw_response) self.logger.debug( f"Getting response HTTP. URI: {self.endpoint_uri}, " diff --git a/web3/providers/rpc/utils.py b/web3/providers/rpc/utils.py new file mode 100644 index 0000000000..c829448dc2 --- /dev/null +++ b/web3/providers/rpc/utils.py @@ -0,0 +1,92 @@ +import asyncio +from typing import Sequence, Type + +import requests +from pydantic import BaseModel + +from web3.types import RPCEndpoint + + +REQUEST_RETRY_ALLOWLIST = [ + "admin", + "miner", + "net", + "txpool", + "testing", + "evm", + "eth_protocolVersion", + "eth_syncing", + "eth_coinbase", + "eth_mining", + "eth_hashrate", + "eth_chainId", + "eth_gasPrice", + "eth_accounts", + "eth_blockNumber", + "eth_getBalance", + "eth_getStorageAt", + "eth_getProof", + "eth_getCode", + "eth_getBlockByNumber", + "eth_getBlockByHash", + "eth_getBlockTransactionCountByNumber", + "eth_getBlockTransactionCountByHash", + "eth_getUncleCountByBlockNumber", + "eth_getUncleCountByBlockHash", + "eth_getTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getTransactionByBlockNumberAndIndex", + "eth_getTransactionReceipt", + "eth_getTransactionCount", + "eth_getRawTransactionByHash", + "eth_call", + "eth_estimateGas", + "eth_createAccessList", + "eth_maxPriorityFeePerGas", + "eth_newBlockFilter", + "eth_newPendingTransactionFilter", + "eth_newFilter", + "eth_getFilterChanges", + "eth_getFilterLogs", + "eth_getLogs", + "eth_uninstallFilter", + "eth_getCompilers", + "eth_getWork", + "eth_sign", + "eth_signTypedData", + "eth_sendRawTransaction", + "personal_importRawKey", + "personal_newAccount", + "personal_listAccounts", + "personal_listWallets", + "personal_lockAccount", + "personal_unlockAccount", + "personal_ecRecover", + "personal_sign", + "personal_signTypedData", +] + + +def check_if_retry_on_failure( + method: RPCEndpoint, + allowlist: Sequence[str] = None, +) -> bool: + if allowlist is None: + allowlist = REQUEST_RETRY_ALLOWLIST + + if method in allowlist or method.split("_")[0]: + return True + else: + return False + + +class ExceptionRetryConfiguration(BaseModel): + errors: Sequence[Type[BaseException]] = ( + ConnectionError, + requests.HTTPError, + asyncio.Timeout, + requests.Timeout, + ) + retries: int = 5 + backoff_factor: float = 0.5 + method_allowlist: Sequence[str] = REQUEST_RETRY_ALLOWLIST From 27c2e62416790fafed55d12aa141c2a6ab5c6948 Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 16 Nov 2023 13:54:54 +0300 Subject: [PATCH 02/24] Refactor new middleware onion setup --- web3/manager.py | 46 +++---- web3/middleware/__init__.py | 107 ++++++++++++--- web3/middleware/attrdict.py | 14 +- web3/middleware/base.py | 48 ++++--- web3/middleware/buffered_gas_estimate.py | 41 +++--- web3/middleware/formatting.py | 25 ++-- web3/middleware/gas_price_strategy.py | 24 ++-- web3/middleware/names.py | 37 +++-- .../{geth_poa.py => proof_of_authority.py} | 6 - web3/middleware/signing.py | 129 ++++++------------ web3/middleware/stalecheck.py | 128 ++++++----------- web3/providers/async_base.py | 31 ++++- web3/providers/base.py | 42 ++++-- web3/providers/eth_tester/main.py | 7 +- web3/providers/eth_tester/middleware.py | 83 +++++------ 15 files changed, 380 insertions(+), 388 deletions(-) rename web3/middleware/{geth_poa.py => proof_of_authority.py} (93%) diff --git a/web3/manager.py b/web3/manager.py index 43d1175864..770c040d40 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -38,6 +38,7 @@ ) from web3.middleware import ( attrdict_middleware, + buffered_gas_estimate_middleware, gas_price_strategy_middleware, ens_name_to_address_middleware, validation_middleware, @@ -63,6 +64,9 @@ AsyncWeb3, Web3, ) + from web3.middleware.base import ( # noqa: F401 + Web3Middleware, + ) from web3.providers import ( # noqa: F401 AsyncBaseProvider, BaseProvider, @@ -70,9 +74,6 @@ from web3.providers.websocket.request_processor import ( # noqa: F401 RequestProcessor, ) - from web3.middleware.base import ( # noqa: F401 - Web3Middleware, - ) NULL_RESPONSES = [None, HexBytes("0x"), "0x"] @@ -169,10 +170,10 @@ def default_middlewares(w3: "Web3") -> List[Tuple["Web3Middleware", str]]: """ return [ (gas_price_strategy_middleware, "gas_price_strategy"), - # (ens_name_to_address_middleware, "name_to_address"), + (ens_name_to_address_middleware, "name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), - # (async_buffered_gas_estimate_middleware, "gas_estimate"), + (buffered_gas_estimate_middleware, "gas_estimate"), ] # @@ -181,36 +182,23 @@ def default_middlewares(w3: "Web3") -> List[Tuple["Web3Middleware", str]]: def _make_request( self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any ) -> RPCResponse: - """ - 1. Pipe the request params through the middleware stack - 2. Make the request using the provider - 3. Pipe the raw response through the middleware stack - """ + provider = cast("BaseProvider", self.provider) + request_func = provider.request_func( + cast("Web3", self.w3), cast(MiddlewareOnion, self.middleware_onion) + ) self.logger.debug(f"Making request. Method: {method}") - for middleware, _name in self.middleware_onion.middlewares: - params = middleware.process_request_params(self.w3, method, params) - response = self.provider.make_request(method, params) - for middleware, _name in reversed(self.middleware_onion.middlewares): - response = middleware.process_response(self.w3, method, response) - return response + return request_func(method, params) async def _coro_make_request( self, method: Union[RPCEndpoint, Callable[..., RPCEndpoint]], params: Any ) -> RPCResponse: - """ - 1. Pipe the request params through the middleware stack - 2. Make the request using the provider - 3. Pipe the raw response through the middleware stack - """ + provider = cast("AsyncBaseProvider", self.provider) + request_func = await provider.request_func( + cast("AsyncWeb3", self.w3), + cast(AsyncMiddlewareOnion, self.middleware_onion), + ) self.logger.debug(f"Making request. Method: {method}") - for middleware, _name in self.middleware_onion.middlewares: - params = await middleware.async_process_request_params( - self.w3, method, params - ) - response = await self.provider.make_request(method, params) - for middleware, _name in reversed(self.middleware_onion.middlewares): - response = await middleware.async_process_response(self.w3, response) - return response + return await request_func(method, params) # # formatted_response parses and validates JSON-RPC responses for expected diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 5a5fd489c1..0e57026cde 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -1,3 +1,17 @@ +import functools +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Coroutine, + Sequence, +) + +from web3.types import ( + RPCEndpoint, + RPCResponse, +) + from .abi import ( abi_middleware, ) @@ -8,8 +22,10 @@ from .attrdict import ( attrdict_middleware, ) +from .base import ( + Web3Middleware, +) from .buffered_gas_estimate import ( - async_buffered_gas_estimate_middleware, buffered_gas_estimate_middleware, ) from .cache import ( @@ -20,29 +36,22 @@ construct_simple_cache_middleware, construct_time_based_cache_middleware, ) -from .exception_handling import ( - construct_exception_handler_middleware, -) -from .filter import ( - async_local_filter_middleware, - local_filter_middleware, -) -from .fixture import ( - async_construct_error_generator_middleware, - async_construct_result_generator_middleware, - construct_error_generator_middleware, - construct_fixture_middleware, - construct_result_generator_middleware, -) from .gas_price_strategy import ( gas_price_strategy_middleware, ) -from .geth_poa import ( +from .proof_of_authority import ( extradata_to_poa_middleware, ) from .names import ( ens_name_to_address_middleware, ) +from .filter import ( + async_local_filter_middleware, + local_filter_middleware, +) +from .gas_price_strategy import ( + GasPriceStrategyMiddleware, +) from .normalize_request_parameters import ( request_parameter_normalizer, ) @@ -50,12 +59,72 @@ pythonic_middleware, ) from .signing import ( - construct_sign_and_send_raw_middleware, + SignAndSendRawMiddleware, ) from .stalecheck import ( - async_make_stalecheck_middleware, - make_stalecheck_middleware, + StaleCheckMiddleware, ) from .validation import ( validation_middleware, ) + +if TYPE_CHECKING: + from web3 import AsyncWeb3, Web3 + + +def combine_middlewares( + middlewares: Sequence[Web3Middleware], + w3: "Web3", + provider_request_fn: Callable[[RPCEndpoint, Any], Any], +) -> Callable[..., RPCResponse]: + """ + Returns a callable function which takes method and params as positional arguments + and passes these args through the request processors, makes the request, and passes + the response through the response processors. + """ + [setattr(middleware, "_w3", w3) for middleware in middlewares] + request_processors = [middleware.request_processor for middleware in middlewares] + response_processors = [ + middleware.response_processor for middleware in reversed(middlewares) + ] + return lambda method, params_or_response: functools.reduce( + lambda p_o_r, processor: processor(method, p_o_r), + response_processors, + provider_request_fn( + method, + functools.reduce( + lambda p, processor: processor(method, p), + request_processors, + params_or_response, + ), + ), + ) + + +async def async_combine_middlewares( + middlewares: Sequence[Web3Middleware], + async_w3: "AsyncWeb3", + provider_request_fn: Callable[[RPCEndpoint, Any], Any], +) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: + """ + Returns a callable function which takes method and params as positional arguments + and passes these args through the request processors, makes the request, and passes + the response through the response processors. + """ + [setattr(middleware, "_w3", async_w3) for middleware in middlewares] + async_request_processors = [ + middleware.async_request_processor for middleware in middlewares + ] + async_response_processors = [ + middleware.async_response_processor for middleware in reversed(middlewares) + ] + + async def async_request_fn(method: RPCEndpoint, params: Any) -> RPCResponse: + for processor in async_request_processors: + params = await processor(method, params) + response = await provider_request_fn(method, params) + for processor in async_response_processors: + response = await processor(method, response) + return response + + return async_request_fn diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index e433071243..3c0eddbd19 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -2,8 +2,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, - Optional, cast, ) @@ -58,9 +56,7 @@ class AttributeDictMiddleware(Web3Middleware, ABC): (e.g. my_attribute_dict.property1) will not preserve typing. """ - def process_response( - self, w3: "Web3", method: "RPCEndpoint", response: "RPCResponse" - ) -> Any: + def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: if "result" in response: return assoc( response, "result", AttributeDict.recursive(response["result"]) @@ -70,12 +66,12 @@ def process_response( # -- async -- # - async def async_process_response( - self, async_w3: "AsyncWeb3", method: "RPCEndpoint", response: "RPCResponse" + async def async_response_processor( + self, method: "RPCEndpoint", response: "RPCResponse" ) -> Any: - if async_w3.provider.has_persistent_connection: + if self._w3.provider.has_persistent_connection: # asynchronous response processing - provider = cast("PersistentConnectionProvider", async_w3.provider) + provider = cast("PersistentConnectionProvider", self._w3.provider) provider._request_processor.append_middleware_response_processor( response, _handle_async_response ) diff --git a/web3/middleware/base.py b/web3/middleware/base.py index 49b83a640b..2573199617 100644 --- a/web3/middleware/base.py +++ b/web3/middleware/base.py @@ -1,9 +1,11 @@ -from abc import abstractmethod -from typing import Any, TYPE_CHECKING - +from typing import ( + TYPE_CHECKING, + Any, + TypeVar, +) if TYPE_CHECKING: - from web3 import ( + from web3 import ( # noqa: F401 AsyncWeb3, Web3, ) @@ -13,29 +15,37 @@ ) +WEB3 = TypeVar("WEB3", "AsyncWeb3", "Web3") + + class Web3Middleware: - @abstractmethod - def process_request_params( - self, w3: "Web3", method: "RPCEndpoint", params: Any - ) -> Any: + """ + Base class for web3.py middleware. This class is not meant to be used directly, + but instead inherited from. + """ + + _w3: WEB3 + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: return params - @abstractmethod - def process_response( - self, w3: "Web3", method: "RPCEndpoint", response: "RPCResponse" - ) -> Any: + def response_processor( + self, method: "RPCEndpoint", response: "RPCResponse" + ) -> "RPCResponse": return response # -- async -- # - @abstractmethod - async def async_process_request_params( - self, async_w3: "AsyncWeb3", method: "RPCEndpoint", params: Any + async def async_request_processor( + self, + method: "RPCEndpoint", + params: Any, ) -> Any: return params - @abstractmethod - async def async_process_response( - self, async_w3: "AsyncWeb3", method: "RPCEndpoint", response: "RPCResponse" - ) -> Any: + async def async_response_processor( + self, + method: "RPCEndpoint", + response: "RPCResponse", + ) -> "RPCResponse": return response diff --git a/web3/middleware/buffered_gas_estimate.py b/web3/middleware/buffered_gas_estimate.py index 60f1cc739b..0e79271cdd 100644 --- a/web3/middleware/buffered_gas_estimate.py +++ b/web3/middleware/buffered_gas_estimate.py @@ -1,7 +1,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, ) from eth_utils.toolz import ( @@ -14,10 +13,11 @@ from web3._utils.transactions import ( get_buffered_gas_estimate, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddlewareCoroutine, RPCEndpoint, - RPCResponse, ) if TYPE_CHECKING: @@ -27,34 +27,35 @@ ) -def buffered_gas_estimate_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: +class BufferedGasEstimateMiddleware(Web3Middleware): + """ + Includes a gas estimate for all transactions that do not already have a gas value. + """ + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] if "gas" not in transaction: transaction = assoc( transaction, "gas", - hex(get_buffered_gas_estimate(w3, transaction)), + hex(get_buffered_gas_estimate(self._w3, transaction)), ) - return make_request(method, [transaction]) - return make_request(method, params) - - return middleware + params = (transaction,) + return params + # -- async -- # -async def async_buffered_gas_estimate_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] if "gas" not in transaction: - gas_estimate = await async_get_buffered_gas_estimate(w3, transaction) + gas_estimate = await async_get_buffered_gas_estimate( + self._w3, transaction + ) transaction = assoc(transaction, "gas", hex(gas_estimate)) - return await make_request(method, [transaction]) - return await make_request(method, params) + params = (transaction,) + return params + - return middleware +buffered_gas_estimate_middleware = BufferedGasEstimateMiddleware() diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 356dd398c3..8d96a64b84 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -4,7 +4,6 @@ Callable, Coroutine, Optional, - TypeVar, cast, ) @@ -127,13 +126,11 @@ def __init__( self.sync_formatters_builder = sync_formatters_builder self.async_formatters_builder = async_formatters_builder - def process_request_params( - self, w3: "Web3", method: "RPCEndpoint", params: Any - ) -> Any: + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if self.sync_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - self.sync_formatters_builder(w3, method), + self.sync_formatters_builder(self._w3, method), ) self.request_formatters = formatters.pop("request_formatters") @@ -143,13 +140,11 @@ def process_request_params( return params - def process_response( - self, w3: "Web3", method: RPCEndpoint, response: "RPCResponse" - ) -> Any: + def response_processor(self, method: RPCEndpoint, response: "RPCResponse") -> Any: if self.sync_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - self.sync_formatters_builder(w3, method), + self.sync_formatters_builder(self._w3, method), ) self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] @@ -163,13 +158,11 @@ def process_response( # -- async -- # - async def async_process_request_params( - self, async_w3: "AsyncWeb3", method: "RPCEndpoint", params: Any - ) -> Any: + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if self.async_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - await self.async_formatters_builder(async_w3, method), + await self.async_formatters_builder(self._w3, method), ) self.request_formatters = formatters.pop("request_formatters") @@ -179,13 +172,13 @@ async def async_process_request_params( return params - async def async_process_response( - self, async_w3: "AsyncWeb3", method: RPCEndpoint, response: "RPCResponse" + async def async_response_processor( + self, method: RPCEndpoint, response: "RPCResponse" ) -> Any: if self.async_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - await self.async_formatters_builder(async_w3, method), + await self.async_formatters_builder(self._w3, method), ) self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index ac6e4a8090..48bf351a49 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -1,4 +1,3 @@ -from abc import ABC from typing import ( TYPE_CHECKING, Any, @@ -79,7 +78,7 @@ def validate_transaction_params( return transaction -class GasPriceStrategyMiddleware(Web3Middleware, ABC): +class GasPriceStrategyMiddleware(Web3Middleware): """ - Uses a gas price strategy if one is set. This is only supported for legacy transactions. It is recommended to send dynamic fee transactions @@ -88,32 +87,29 @@ class GasPriceStrategyMiddleware(Web3Middleware, ABC): - Validates transaction params against legacy and dynamic fee txn values. """ - def process_request_params( - self, w3: "Web3", method: RPCEndpoint, params: Any - ) -> Any: + def request_processor(self, method: RPCEndpoint, params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] - generated_gas_price = w3.eth.generate_gas_price(transaction) - latest_block = w3.eth.get_block("latest") + generated_gas_price = self._w3.eth.generate_gas_price(transaction) + latest_block = self._w3.eth.get_block("latest") transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) - return (transaction,) + params = (transaction,) + return params # -- async -- # - async def async_process_request_params( - self, async_w3: "AsyncWeb3", method: RPCEndpoint, params: Any - ) -> Any: + async def async_request_processor(self, method: RPCEndpoint, params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] - generated_gas_price = async_w3.eth.generate_gas_price(transaction) - latest_block = await async_w3.eth.get_block("latest") + generated_gas_price = self._w3.eth.generate_gas_price(transaction) + latest_block = await self._w3.eth.get_block("latest") transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) - return (transaction,) + params = (transaction,) return params diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 54ca201e15..8376933a36 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -1,4 +1,3 @@ -from abc import ABC from typing import ( TYPE_CHECKING, Any, @@ -22,7 +21,9 @@ from web3.types import ( RPCEndpoint, ) -from .base import Web3Middleware +from .base import ( + Web3Middleware, +) from .._utils.abi import ( abi_data_tree, @@ -96,26 +97,22 @@ async def async_apply_ens_to_address_conversion( ) -class EnsNameToAddressMiddleware(Web3Middleware, ABC): - def __init__(self) -> None: - # the sync version of this middleware utilizes the FormattingMiddleware - self._formatting_middleware = FormattingMiddleware( - request_formatters=abi_request_formatters( # type: ignore - [abi_ens_resolver], - RPC_ABIS, # type: ignore - ) - ) +class EnsNameToAddressMiddleware(Web3Middleware): + _formatting_middleware = None - def process_request_params( - self, w3: "Web3", method: "RPCEndpoint", params: Any - ) -> Any: - return self._formatting_middleware.process_request_params(w3, method, params) + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if self._formatting_middleware is None: + normalizers = [ + abi_ens_resolver(self._w3), + ] + self._formatting_middleware = FormattingMiddleware( + request_formatters=abi_request_formatters(normalizers, RPC_ABIS) # type: ignore # noqa: E501 + ) + return self._formatting_middleware.request_processor(method, params) # -- async -- # - async def async_process_request_params( - self, async_w3: "AsyncWeb3", method: "RPCEndpoint", params: Any - ) -> Any: + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: abi_types_for_method = RPC_ABIS.get(method, None) if abi_types_for_method is not None: @@ -123,7 +120,7 @@ async def async_process_request_params( # eth_subscribe optional logs params are unique. # Handle them separately here. (formatted_dict,) = await async_apply_ens_to_address_conversion( - async_w3, + self._w3, (params[1],), { "address": "address", @@ -134,7 +131,7 @@ async def async_process_request_params( else: params = await async_apply_ens_to_address_conversion( - async_w3, + self._w3, params, abi_types_for_method, ) diff --git a/web3/middleware/geth_poa.py b/web3/middleware/proof_of_authority.py similarity index 93% rename from web3/middleware/geth_poa.py rename to web3/middleware/proof_of_authority.py index e27b493394..a655a093fa 100644 --- a/web3/middleware/geth_poa.py +++ b/web3/middleware/proof_of_authority.py @@ -1,7 +1,5 @@ from typing import ( TYPE_CHECKING, - Any, - Callable, ) from eth_utils import ( @@ -25,10 +23,6 @@ RPC, ) from web3.middleware.formatting import FormattingMiddleware -from web3.types import ( - AsyncMiddlewareCoroutine, - RPCEndpoint, -) if TYPE_CHECKING: from web3 import ( # noqa: F401 diff --git a/web3/middleware/signing.py b/web3/middleware/signing.py index 0ccc629f21..2d4805d09e 100644 --- a/web3/middleware/signing.py +++ b/web3/middleware/signing.py @@ -5,7 +5,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Collection, Iterable, Tuple, @@ -52,12 +51,11 @@ fill_nonce, fill_transaction_defaults, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, - Middleware, RPCEndpoint, - RPCResponse, TxParams, ) @@ -141,99 +139,62 @@ def format_transaction(transaction: TxParams) -> TxParams: ) -def construct_sign_and_send_raw_middleware( - private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]] -) -> Middleware: - """Capture transactions sign and send as raw transactions - - - Keyword arguments: - private_key_or_account -- A single private key or a tuple, - list or set of private keys. Keys can be any of the following formats: - - An eth_account.LocalAccount object - - An eth_keys.PrivateKey object - - A raw private key as a hex string or byte string - """ - - accounts = gen_normalized_accounts(private_key_or_account) +class SignAndSendRawMiddleware(Web3Middleware): + format_and_fill_tx = None - def sign_and_send_raw_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - format_and_fill_tx = compose( - format_transaction, fill_transaction_defaults(w3), fill_nonce(w3) - ) + def __init__( + self, private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]] + ): + self._accounts = gen_normalized_accounts(private_key_or_account) + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method != "eth_sendTransaction": + return params + else: + if self.format_and_fill_tx is None: + self.format_and_fill_tx = compose( + format_transaction, + fill_transaction_defaults(self._w3), + fill_nonce(self._w3), + ) + + filled_transaction = self.format_and_fill_tx(params[0]) + tx_from = filled_transaction.get("from", None) - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method != "eth_sendTransaction": - return make_request(method, params) + if tx_from is None or ( + tx_from is not None and tx_from not in self._accounts + ): + return params else: - transaction = format_and_fill_tx(params[0]) - - if "from" not in transaction: - return make_request(method, params) - elif transaction.get("from") not in accounts: - return make_request(method, params) - - account = accounts[transaction["from"]] - raw_tx = account.sign_transaction(transaction).rawTransaction - - return make_request(RPCEndpoint("eth_sendRawTransaction"), [raw_tx.hex()]) - - return middleware + account = self._accounts[to_checksum_address(tx_from)] + raw_tx = account.sign_transaction(filled_transaction).rawTransaction - return sign_and_send_raw_middleware + return (raw_tx.hex(),) + # -- async -- # -# -- async -- # - - -async def async_construct_sign_and_send_raw_middleware( - private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]] -) -> AsyncMiddleware: - """ - Capture transactions & sign and send as raw transactions - - Keyword arguments: - private_key_or_account -- A single private key or a tuple, - list or set of private keys. Keys can be any of the following formats: - - An eth_account.LocalAccount object - - An eth_keys.PrivateKey object - - A raw private key as a hex string or byte string - """ - accounts = gen_normalized_accounts(private_key_or_account) - - async def async_sign_and_send_raw_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method != "eth_sendTransaction": - # quick exit if not `eth_sendTransaction` - return await make_request(method, params) + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method != "eth_sendTransaction": + return params + else: formatted_transaction = format_transaction(params[0]) filled_transaction = await async_fill_transaction_defaults( - async_w3, + self._w3, formatted_transaction, ) filled_transaction = await async_fill_nonce( - async_w3, + self._w3, filled_transaction, ) - tx_from = filled_transaction.get("from", None) - if tx_from is None or (tx_from is not None and tx_from not in accounts): - return await make_request(method, params) - - account = accounts[to_checksum_address(tx_from)] - raw_tx = account.sign_transaction(filled_transaction).rawTransaction - - return await make_request( - RPCEndpoint("eth_sendRawTransaction"), - [raw_tx.hex()], - ) - - return middleware + if tx_from is None or ( + tx_from is not None and tx_from not in self._accounts + ): + return params + else: + account = self._accounts[to_checksum_address(tx_from)] + raw_tx = account.sign_transaction(filled_transaction).rawTransaction - return async_sign_and_send_raw_middleware + return (raw_tx.hex(),) diff --git a/web3/middleware/stalecheck.py b/web3/middleware/stalecheck.py index 16b98715a3..65cacc8956 100644 --- a/web3/middleware/stalecheck.py +++ b/web3/middleware/stalecheck.py @@ -11,13 +11,12 @@ from web3.exceptions import ( StaleBlockchain, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, BlockData, - Middleware, RPCEndpoint, - RPCResponse, ) if TYPE_CHECKING: @@ -35,86 +34,41 @@ def _is_fresh(block: BlockData, allowable_delay: int) -> bool: return False -def make_stalecheck_middleware( - allowable_delay: int, - skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, -) -> Middleware: - """ - Use to require that a function will run only of the blockchain is recently updated. - - This middleware takes an argument, so unlike other middleware, you must make the - middleware with a method call. - - For example: `make_stalecheck_middleware(60*5)` - - If the latest block in the chain is older than 5 minutes in this example, then the - middleware will raise a StaleBlockchain exception. - """ - if allowable_delay <= 0: - raise ValueError( - "You must set a positive allowable_delay in seconds for this middleware" - ) - - def stalecheck_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - cache: Dict[str, Optional[BlockData]] = {"latest": None} - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method not in skip_stalecheck_for_methods: - if not _is_fresh(cache["latest"], allowable_delay): - latest = w3.eth.get_block("latest") - if _is_fresh(latest, allowable_delay): - cache["latest"] = latest - else: - raise StaleBlockchain(latest, allowable_delay) - - return make_request(method, params) - - return middleware - - return stalecheck_middleware - - -# -- async -- # - - -async def async_make_stalecheck_middleware( - allowable_delay: int, - skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, -) -> AsyncMiddleware: - """ - Use to require that a function will run only of the blockchain is recently updated. - - This middleware takes an argument, so unlike other middleware, you must make the - middleware with a method call. - - For example: `async_make_stalecheck_middleware(60*5)` - - If the latest block in the chain is older than 5 minutes in this example, then the - middleware will raise a StaleBlockchain exception. - """ - if allowable_delay <= 0: - raise ValueError( - "You must set a positive allowable_delay in seconds for this middleware" - ) - - async def stalecheck_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - cache: Dict[str, Optional[BlockData]] = {"latest": None} - - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method not in skip_stalecheck_for_methods: - if not _is_fresh(cache["latest"], allowable_delay): - latest = await w3.eth.get_block("latest") - if _is_fresh(latest, allowable_delay): - cache["latest"] = latest - else: - raise StaleBlockchain(latest, allowable_delay) - - return await make_request(method, params) - - return middleware - - return stalecheck_middleware +class StaleCheckMiddleware(Web3Middleware): + def __init__( + self, + allowable_delay: int, + skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, + ) -> None: + if allowable_delay <= 0: + raise ValueError( + "You must set a positive allowable_delay in seconds for this middleware" + ) + + self.allowable_delay = allowable_delay + self.skip_stalecheck_for_methods = skip_stalecheck_for_methods + self.cache: Dict[str, Optional[BlockData]] = {"latest": None} + + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method not in self.skip_stalecheck_for_methods: + if not _is_fresh(self.cache["latest"], self.allowable_delay): + latest = self._w3.eth.get_block("latest") + if _is_fresh(latest, self.allowable_delay): + self.cache["latest"] = latest + else: + raise StaleBlockchain(latest, self.allowable_delay) + + return params + + # -- async -- # + + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method not in self.skip_stalecheck_for_methods: + if not _is_fresh(self.cache["latest"], self.allowable_delay): + latest = await self._w3.eth.get_block("latest") + if _is_fresh(latest, self.allowable_delay): + self.cache["latest"] = latest + else: + raise StaleBlockchain(latest, self.allowable_delay) + + return params diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 1a0d1ac73c..76599d8d6e 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -4,6 +4,7 @@ Any, Callable, Coroutine, + Set, Tuple, cast, ) @@ -21,9 +22,13 @@ from web3.exceptions import ( ProviderConnectionError, ) +from web3.middleware import ( + async_combine_middlewares, +) from web3.middleware.base import Web3Middleware from web3.types import ( AsyncMiddleware, + AsyncMiddlewareOnion, MiddlewareOnion, RPCEndpoint, RPCResponse, @@ -38,13 +43,9 @@ class AsyncBaseProvider: _middlewares: Tuple[Web3Middleware, ...] = () - # a tuple of (all_middlewares, request_func) _request_func_cache: Tuple[ - Tuple[AsyncMiddleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] - ] = ( - None, - None, - ) + Tuple[Web3Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] + ] = (None, None) is_async = True has_persistent_connection = False @@ -60,6 +61,24 @@ def middlewares(self, values: MiddlewareOnion) -> None: # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] self._middlewares = tuple(values) # type: ignore + async def request_func( + self, async_w3: "AsyncWeb3", outer_middlewares: AsyncMiddlewareOnion + ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: + # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares + all_middlewares: Tuple[Web3Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 + + cache_key = self._request_func_cache[0] + if cache_key != all_middlewares: + self._request_func_cache = ( + all_middlewares, + await async_combine_middlewares( + middlewares=all_middlewares, + async_w3=async_w3, + provider_request_fn=self.make_request, + ), + ) + return self._request_func_cache[-1] + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") diff --git a/web3/providers/base.py b/web3/providers/base.py index b88740db3c..34fc83e883 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -3,7 +3,6 @@ TYPE_CHECKING, Any, Callable, - Sequence, Tuple, cast, ) @@ -20,7 +19,10 @@ from web3.exceptions import ( ProviderConnectionError, ) -from web3.middleware.base import Web3Middleware +from web3.middleware import ( + Web3Middleware, + combine_middlewares, +) from web3.types import ( Middleware, MiddlewareOnion, @@ -33,12 +35,10 @@ class BaseProvider: - _middlewares: Tuple[Web3Middleware, ...] = () - # a tuple of (all_middlewares, request_func) - _request_func_cache: Tuple[Tuple[Middleware, ...], Callable[..., RPCResponse]] = ( - None, - None, - ) + _middlewares: Tuple[Middleware, ...] = () + _request_func_cache: Tuple[ + Tuple[Web3Middleware, ...], Callable[..., RPCResponse] + ] = (None, None) is_async = False has_persistent_connection = False @@ -54,6 +54,32 @@ def middlewares(self, values: MiddlewareOnion) -> None: # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] self._middlewares = tuple(values) # type: ignore + def request_func( + self, w3: "Web3", outer_middlewares: MiddlewareOnion + ) -> Callable[..., RPCResponse]: + """ + @param w3 is the web3 instance + @param outer_middlewares is an iterable of middlewares, + ordered by first to execute + @returns a function that calls all the middleware and + eventually self.make_request() + """ + # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares + all_middlewares: Tuple[Web3Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 + + cache_key = self._request_func_cache[0] + if cache_key != all_middlewares: + self._request_func_cache = ( + all_middlewares, + combine_middlewares( + middlewares=all_middlewares, + w3=w3, + provider_request_fn=self.make_request, + ), + ) + + return self._request_func_cache[-1] + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index 8952dbb746..ab3bc681f7 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -24,7 +24,7 @@ attrdict_middleware, ) from web3.middleware.buffered_gas_estimate import ( - async_buffered_gas_estimate_middleware, + buffered_gas_estimate_middleware, ) from web3.providers import ( BaseProvider, @@ -51,8 +51,9 @@ class AsyncEthereumTesterProvider(AsyncBaseProvider): middlewares = ( attrdict_middleware, - # async_default_transaction_fields_middleware, - # async_ethereum_tester_middleware, + buffered_gas_estimate_middleware, + default_transaction_fields_middleware, + ethereum_tester_middleware, ) def __init__(self) -> None: diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index f51a122cca..6beb3385d5 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -40,14 +40,14 @@ from web3._utils.method_formatters import ( apply_list_to_array_formatter, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.middleware.formatting import ( FormattingMiddleware, ) from web3.types import ( - AsyncMiddlewareCoroutine, - Middleware, RPCEndpoint, - RPCResponse, TxParams, ) @@ -313,11 +313,6 @@ def is_hexstr(value: Any) -> bool: } -ethereum_tester_middleware = FormattingMiddleware( - request_formatters=request_formatters, result_formatters=result_formatters -) - - def guess_from(w3: "Web3", _: TxParams) -> ChecksumAddress: if w3.eth.coinbase: return w3.eth.coinbase @@ -339,40 +334,9 @@ def fill_default( return assoc(transaction, field, guess_val) -def default_transaction_fields_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in ( - "eth_call", - "eth_estimateGas", - "eth_sendTransaction", - "eth_createAccessList", - ): - fill_default_from = fill_default("from", guess_from, w3) - filled_transaction = pipe( - params[0], - fill_default_from, - ) - return make_request(method, [filled_transaction] + list(params)[1:]) - else: - return make_request(method, params) - - return middleware - - # --- async --- # -# async def async_ethereum_tester_middleware( # type: ignore -# make_request, web3: "AsyncWeb3" -# ) -> Middleware: -# middleware = await async_construct_formatting_middleware( -# request_formatters=request_formatters, result_formatters=result_formatters -# ) -# return await middleware(make_request, web3) - - async def async_guess_from( async_w3: "AsyncWeb3", _: TxParams ) -> Optional[ChecksumAddress]: @@ -400,20 +364,43 @@ async def async_fill_default( return assoc(transaction, field, guess_val) -async def async_default_transaction_fields_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: +# --- define middleware --- # + + +class DefaultTransactionFieldsMiddleware(Web3Middleware): + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: + if method in ( + "eth_call", + "eth_estimateGas", + "eth_sendTransaction", + "eth_createAccessList", + ): + fill_default_from = fill_default("from", guess_from, self._w3) + filled_transaction = pipe( + params[0], + fill_default_from, + ) + params = [filled_transaction] + list(params)[1:] + return params + + # --- async --- # + + async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method in ( "eth_call", "eth_estimateGas", "eth_sendTransaction", + "eth_createAccessList", ): filled_transaction = await async_fill_default( - "from", async_guess_from, async_w3, params[0] + "from", async_guess_from, self._w3, params[0] ) - return await make_request(method, [filled_transaction] + list(params)[1:]) - else: - return await make_request(method, params) + params = [filled_transaction] + list(params)[1:] - return middleware + return params + + +ethereum_tester_middleware = FormattingMiddleware( + request_formatters=request_formatters, result_formatters=result_formatters +) +default_transaction_fields_middleware = DefaultTransactionFieldsMiddleware() From 30351ead98b63a665339de7e965dc60a232e1c84 Mon Sep 17 00:00:00 2001 From: fselmo Date: Sat, 18 Nov 2023 17:18:59 +0300 Subject: [PATCH 03/24] Refactor caching as a decorator on ``make_request()`` - Remove all caching middleware. - Refactor request caching via configurations on the provider classes and decorators on their respective ``make_request()`` methods. --- web3/_utils/caching.py | 67 ++++ web3/middleware/__init__.py | 12 - web3/middleware/async_cache.py | 99 ------ web3/middleware/cache.py | 388 ----------------------- web3/providers/async_base.py | 24 +- web3/providers/base.py | 24 ++ web3/providers/rpc/async_rpc.py | 4 + web3/providers/rpc/rpc.py | 4 + web3/providers/websocket/websocket_v2.py | 18 ++ 9 files changed, 140 insertions(+), 500 deletions(-) delete mode 100644 web3/middleware/async_cache.py diff --git a/web3/_utils/caching.py b/web3/_utils/caching.py index e1339d6085..7af5235f1c 100644 --- a/web3/_utils/caching.py +++ b/web3/_utils/caching.py @@ -1,4 +1,5 @@ import collections +import copy import hashlib from typing import ( TYPE_CHECKING, @@ -8,6 +9,8 @@ Tuple, ) +from toolz import merge + from eth_utils import ( is_boolean, is_bytes, @@ -19,7 +22,11 @@ to_bytes, ) + if TYPE_CHECKING: + from web3.providers import ( # noqa: F401 + BaseProvider, + ) from web3.types import ( RPCEndpoint, ) @@ -58,3 +65,63 @@ def __init__( self.response_formatters = response_formatters self.subscription_id = subscription_id self.middleware_response_processors: List[Callable[..., Any]] = [] + + +def is_cacheable_request(provider: "BaseProvider", method: "RPCEndpoint") -> bool: + if provider._cache_allowed_requests and method in provider._cacheable_requests: + return True + + +# -- request caching decorators -- # + + +def handle_request_caching(func): + def wrapper(*args, **kwargs): + # args=(self, method, params) - where "self" should be BaseProvider instance + provider, method, params = args + + if is_cacheable_request(provider, method): + request_cache = provider._request_cache + cache_key = generate_cache_key((method, params, kwargs)) + cache_result = request_cache.get_cache_entry(cache_key) + if cache_result is not None: + return cache_result + else: + response = func(*args, **kwargs) + request_cache.cache(cache_key, response) + return response + else: + return func(*args, **kwargs) + + return wrapper + + +def async_handle_request_caching(func): + async def wrapper(*args, **kwargs): + # args=(self, method, params) - where "self" should be the provider instance + provider, method, params = args + + if is_cacheable_request(provider, method): + if provider.has_persistent_connection: + next_request_id = provider._effective_request_number + 1 + else: + next_request_id = copy.deepcopy(next(provider.request_counter)) + + request_cache = provider._request_cache + cache_key = generate_cache_key((method, params, kwargs)) + cache_result = request_cache.get_cache_entry(cache_key) + if cache_result is not None: + # Increment request counter. This both makes the request unique, + # independent of where the response came from, and ensures that + # calculated request and response id matching can happen for + # persistent connection providers. + next(provider.request_counter) + return merge(cache_result, {"id": next_request_id}) + else: + response = await func(*args, **kwargs) + request_cache.cache(cache_key, response) + return response + else: + return await func(*args, **kwargs) + + return wrapper diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 0e57026cde..198eed1e29 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -15,10 +15,6 @@ from .abi import ( abi_middleware, ) -from .async_cache import ( - _async_simple_cache_middleware as async_simple_cache_middleware, - async_construct_simple_cache_middleware, -) from .attrdict import ( attrdict_middleware, ) @@ -28,14 +24,6 @@ from .buffered_gas_estimate import ( buffered_gas_estimate_middleware, ) -from .cache import ( - _latest_block_based_cache_middleware as latest_block_based_cache_middleware, - _simple_cache_middleware as simple_cache_middleware, - _time_based_cache_middleware as time_based_cache_middleware, - construct_latest_block_based_cache_middleware, - construct_simple_cache_middleware, - construct_time_based_cache_middleware, -) from .gas_price_strategy import ( gas_price_strategy_middleware, ) diff --git a/web3/middleware/async_cache.py b/web3/middleware/async_cache.py deleted file mode 100644 index ab354cdd58..0000000000 --- a/web3/middleware/async_cache.py +++ /dev/null @@ -1,99 +0,0 @@ -from concurrent.futures import ( - ThreadPoolExecutor, -) -import threading -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, -) - -from web3._utils.async_caching import ( - async_lock, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3.middleware.cache import ( - SIMPLE_CACHE_RPC_WHITELIST, - _should_cache_response, -) -from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, - Middleware, - RPCEndpoint, - RPCResponse, -) -from web3.utils.caching import ( - SimpleCache, -) - -if TYPE_CHECKING: - from web3 import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - -_async_request_thread_pool = ThreadPoolExecutor() - - -async def async_construct_simple_cache_middleware( - cache: SimpleCache = None, - rpc_whitelist: Collection[RPCEndpoint] = SIMPLE_CACHE_RPC_WHITELIST, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> AsyncMiddleware: - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` - - :param cache: A ``SimpleCache`` class. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - - async def async_simple_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - lock = threading.Lock() - - # It's not imperative that we define ``_cache`` here rather than in - # ``async_construct_simple_cache_middleware``. Due to the nature of async, - # construction is awaited and doesn't happen at import. This means separate - # instances would still get unique caches. However, to keep the code consistent - # with the synchronous version, and provide less ambiguity, we define it - # similarly to the synchronous version here. - _cache = cache if cache else SimpleCache(256) - - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in rpc_whitelist: - cache_key = generate_cache_key( - f"{threading.get_ident()}:{(method, params)}" - ) - cached_request = _cache.get_cache_entry(cache_key) - if cached_request is not None: - return cached_request - - response = await make_request(method, params) - if should_cache_fn(method, params, response): - async with async_lock(_async_request_thread_pool, lock): - _cache.cache(cache_key, response) - return response - else: - return await make_request(method, params) - - return middleware - - return async_simple_cache_middleware - - -async def _async_simple_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" -) -> Middleware: - middleware = await async_construct_simple_cache_middleware() - return await middleware(make_request, async_w3) diff --git a/web3/middleware/cache.py b/web3/middleware/cache.py index 39e871f6a2..e69de29bb2 100644 --- a/web3/middleware/cache.py +++ b/web3/middleware/cache.py @@ -1,388 +0,0 @@ -import functools -import threading -import time -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Collection, - Dict, - Set, - cast, -) - -from eth_utils import ( - is_list_like, -) -import lru - -from web3._utils.caching import ( - generate_cache_key, -) -from web3._utils.compat import ( - Literal, - TypedDict, -) -from web3.middleware.base import Web3Middleware -from web3.types import ( - BlockData, - BlockNumber, - Middleware, - RPCEndpoint, - RPCResponse, -) -from web3.utils.caching import ( - SimpleCache, -) - -if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 - -SIMPLE_CACHE_RPC_WHITELIST = cast( - Set[RPCEndpoint], - ( - "web3_clientVersion", - "net_version", - "eth_getBlockTransactionCountByHash", - "eth_getUncleCountByBlockHash", - "eth_getBlockByHash", - "eth_getTransactionByHash", - "eth_getTransactionByBlockHashAndIndex", - "eth_getRawTransactionByHash", - "eth_getUncleByBlockHashAndIndex", - "eth_chainId", - ), -) - - -def _should_cache_response( - _method: RPCEndpoint, _params: Any, response: RPCResponse -) -> bool: - return ( - "error" not in response - and "result" in response - and response["result"] is not None - ) - - -def construct_simple_cache_middleware( - cache: SimpleCache = None, - rpc_whitelist: Collection[RPCEndpoint] = None, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> Middleware: - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` - - :param cache: A ``SimpleCache`` class. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - if rpc_whitelist is None: - rpc_whitelist = SIMPLE_CACHE_RPC_WHITELIST - - def simple_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], RPCResponse], _w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - lock = threading.Lock() - - # Setting the cache here, rather than in ``construct_simple_cache_middleware``, - # ensures that each instance of the middleware has its own cache. This is - # important for compatibility with multiple ``Web3`` instances. - _cache = cache if cache else SimpleCache(256) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in rpc_whitelist: - cache_key = generate_cache_key( - f"{threading.get_ident()}:{(method, params)}" - ) - cached_request = _cache.get_cache_entry(cache_key) - if cached_request is not None: - return cached_request - - response = make_request(method, params) - if should_cache_fn(method, params, response): - if lock.acquire(blocking=False): - try: - _cache.cache(cache_key, response) - finally: - lock.release() - return response - else: - return make_request(method, params) - - return middleware - - return simple_cache_middleware - - -_simple_cache_middleware = construct_simple_cache_middleware() - - -class SimpleCacheMiddleware(Web3Middleware): - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` - - :param cache: A ``SimpleCache`` class. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - - -TIME_BASED_CACHE_RPC_WHITELIST = cast( - Set[RPCEndpoint], - { - "eth_coinbase", - "eth_accounts", - }, -) - - -def construct_time_based_cache_middleware( - cache_class: Callable[..., Dict[Any, Any]], - cache_expire_seconds: int = 15, - rpc_whitelist: Collection[RPCEndpoint] = TIME_BASED_CACHE_RPC_WHITELIST, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> Middleware: - """ - Constructs a middleware which caches responses based on the request - ``method`` and ``params`` for a maximum amount of time as specified - - :param cache_class: Any dictionary-like object - :param cache_expire_seconds: The number of seconds an item may be cached - before it should expire. - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - """ - - def time_based_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - cache = cache_class() - lock = threading.Lock() - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - lock_acquired = ( - lock.acquire(blocking=False) if method in rpc_whitelist else False - ) - - try: - if lock_acquired and method in rpc_whitelist: - cache_key = generate_cache_key((method, params)) - if cache_key in cache: - # check that the cached response is not expired. - cached_at, cached_response = cache[cache_key] - cached_for = time.time() - cached_at - - if cached_for <= cache_expire_seconds: - return cached_response - else: - del cache[cache_key] - - # cache either missed or expired so make the request. - response = make_request(method, params) - - if should_cache_fn(method, params, response): - cache[cache_key] = (time.time(), response) - - return response - else: - return make_request(method, params) - finally: - if lock_acquired: - lock.release() - - return middleware - - return time_based_cache_middleware - - -_time_based_cache_middleware = construct_time_based_cache_middleware( - cache_class=functools.partial(lru.LRU, 256), -) - - -BLOCK_NUMBER_RPC_WHITELIST = cast( - Set[RPCEndpoint], - { - "eth_gasPrice", - "eth_blockNumber", - "eth_getBalance", - "eth_getStorageAt", - "eth_getTransactionCount", - "eth_getBlockTransactionCountByNumber", - "eth_getUncleCountByBlockNumber", - "eth_getCode", - "eth_call", - "eth_createAccessList", - "eth_estimateGas", - "eth_getBlockByNumber", - "eth_getTransactionByBlockNumberAndIndex", - "eth_getTransactionReceipt", - "eth_getUncleByBlockNumberAndIndex", - "eth_getLogs", - }, -) - -AVG_BLOCK_TIME_KEY: Literal["avg_block_time"] = "avg_block_time" -AVG_BLOCK_SAMPLE_SIZE_KEY: Literal["avg_block_sample_size"] = "avg_block_sample_size" -AVG_BLOCK_TIME_UPDATED_AT_KEY: Literal[ - "avg_block_time_updated_at" -] = "avg_block_time_updated_at" - - -def _is_latest_block_number_request(method: RPCEndpoint, params: Any) -> bool: - if method != "eth_getBlockByNumber": - return False - elif is_list_like(params) and tuple(params[:1]) == ("latest",): - return True - return False - - -BlockInfoCache = TypedDict( - "BlockInfoCache", - { - "avg_block_time": float, - "avg_block_sample_size": int, - "avg_block_time_updated_at": float, - "latest_block": BlockData, - }, - total=False, -) - - -def construct_latest_block_based_cache_middleware( - cache_class: Callable[..., Dict[Any, Any]], - rpc_whitelist: Collection[RPCEndpoint] = BLOCK_NUMBER_RPC_WHITELIST, - average_block_time_sample_size: int = 240, - default_average_block_time: int = 15, - should_cache_fn: Callable[ - [RPCEndpoint, Any, RPCResponse], bool - ] = _should_cache_response, -) -> Middleware: - """ - Constructs a middleware which caches responses based on the request - ``method``, ``params``, and the current latest block hash. - - :param cache_class: Any dictionary-like object - :param rpc_whitelist: A set of RPC methods which may have their responses cached. - :param average_block_time_sample_size: number of blocks to look back when computing - average block time - :param default_average_block_time: estimated number of seconds per block - :param should_cache_fn: A callable which accepts ``method`` ``params`` and - ``response`` and returns a boolean as to whether the response should be - cached. - - .. note:: - This middleware avoids re-fetching the current latest block for each - request by tracking the current average block time and only requesting - a new block when the last seen latest block is older than the average - block time. - """ - - def latest_block_based_cache_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - cache = cache_class() - block_info: BlockInfoCache = {} - - def _update_block_info_cache() -> None: - avg_block_time = block_info.get( - AVG_BLOCK_TIME_KEY, default_average_block_time - ) - avg_block_sample_size = block_info.get(AVG_BLOCK_SAMPLE_SIZE_KEY, 0) - avg_block_time_updated_at = block_info.get(AVG_BLOCK_TIME_UPDATED_AT_KEY, 0) - - # compute age as counted by number of blocks since the avg_block_time - if avg_block_time == 0: - avg_block_time_age_in_blocks: float = avg_block_sample_size - else: - avg_block_time_age_in_blocks = ( - time.time() - avg_block_time_updated_at - ) / avg_block_time - - if avg_block_time_age_in_blocks >= avg_block_sample_size: - # If the length of time since the average block time as - # measured by blocks is greater than or equal to the number of - # blocks sampled then we need to recompute the average block - # time. - latest_block = w3.eth.get_block("latest") - ancestor_block_number = BlockNumber( - max( - 0, - latest_block["number"] - average_block_time_sample_size, - ) - ) - ancestor_block = w3.eth.get_block(ancestor_block_number) - sample_size = latest_block["number"] - ancestor_block_number - - block_info[AVG_BLOCK_SAMPLE_SIZE_KEY] = sample_size - if sample_size != 0: - block_info[AVG_BLOCK_TIME_KEY] = ( - latest_block["timestamp"] - ancestor_block["timestamp"] - ) / sample_size - else: - block_info[AVG_BLOCK_TIME_KEY] = avg_block_time - block_info[AVG_BLOCK_TIME_UPDATED_AT_KEY] = time.time() - - if "latest_block" in block_info: - latest_block = block_info["latest_block"] - time_since_latest_block = time.time() - latest_block["timestamp"] - - # latest block is too old so update cache - if time_since_latest_block > avg_block_time: - block_info["latest_block"] = w3.eth.get_block("latest") - else: - # latest block has not been fetched so we fetch it. - block_info["latest_block"] = w3.eth.get_block("latest") - - lock = threading.Lock() - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - lock_acquired = ( - lock.acquire(blocking=False) if method in rpc_whitelist else False - ) - - try: - should_try_cache = ( - lock_acquired - and method in rpc_whitelist - and not _is_latest_block_number_request(method, params) - ) - if should_try_cache: - _update_block_info_cache() - latest_block_hash = block_info["latest_block"]["hash"] - cache_key = generate_cache_key((latest_block_hash, method, params)) - if cache_key in cache: - return cache[cache_key] - - response = make_request(method, params) - if should_cache_fn(method, params, response): - cache[cache_key] = response - return response - else: - return make_request(method, params) - finally: - if lock_acquired: - lock.release() - - return middleware - - return latest_block_based_cache_middleware - - -_latest_block_based_cache_middleware = construct_latest_block_based_cache_middleware( - cache_class=functools.partial(lru.LRU, 256), - rpc_whitelist=BLOCK_NUMBER_RPC_WHITELIST, -) diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 76599d8d6e..2ba9d796ca 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -27,12 +27,13 @@ ) from web3.middleware.base import Web3Middleware from web3.types import ( - AsyncMiddleware, AsyncMiddlewareOnion, MiddlewareOnion, RPCEndpoint, RPCResponse, ) +from web3.utils import SimpleCache + if TYPE_CHECKING: from web3 import ( # noqa: F401 @@ -41,12 +42,33 @@ ) +CACHEABLE_REQUESTS = cast( + Set[RPCEndpoint], + ( + "eth_chainId", + "eth_getBlockByHash", + "eth_getBlockTransactionCountByHash", + "eth_getRawTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getTransactionByHash", + "eth_getUncleByBlockHashAndIndex", + "eth_getUncleCountByBlockHash", + "net_version", + "web3_clientVersion", + ), +) + + class AsyncBaseProvider: _middlewares: Tuple[Web3Middleware, ...] = () _request_func_cache: Tuple[ Tuple[Web3Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] ] = (None, None) + _request_cache: SimpleCache = SimpleCache(size=500) + _cache_allowed_requests: bool = True + _cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS + is_async = True has_persistent_connection = False global_ccip_read_enabled: bool = True diff --git a/web3/providers/base.py b/web3/providers/base.py index 34fc83e883..ebd2b19dbe 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -3,6 +3,7 @@ TYPE_CHECKING, Any, Callable, + Set, Tuple, cast, ) @@ -29,17 +30,40 @@ RPCEndpoint, RPCResponse, ) +from web3.utils import SimpleCache + if TYPE_CHECKING: from web3 import Web3 # noqa: F401 +CACHEABLE_REQUESTS = cast( + Set[RPCEndpoint], + ( + "eth_chainId", + "eth_getBlockByHash", + "eth_getBlockTransactionCountByHash", + "eth_getRawTransactionByHash", + "eth_getTransactionByBlockHashAndIndex", + "eth_getTransactionByHash", + "eth_getUncleByBlockHashAndIndex", + "eth_getUncleCountByBlockHash", + "net_version", + "web3_clientVersion", + ), +) + + class BaseProvider: _middlewares: Tuple[Middleware, ...] = () _request_func_cache: Tuple[ Tuple[Web3Middleware, ...], Callable[..., RPCResponse] ] = (None, None) + _request_cache: SimpleCache = SimpleCache(size=500) + _cache_allowed_requests: bool = True + _cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS + is_async = False has_persistent_connection = False global_ccip_read_enabled: bool = True diff --git a/web3/providers/rpc/async_rpc.py b/web3/providers/rpc/async_rpc.py index 95915ad1fa..51f974ab04 100644 --- a/web3/providers/rpc/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -32,6 +32,9 @@ RPCEndpoint, RPCResponse, ) +from ..._utils.caching import ( + async_handle_request_caching, +) from ...datastructures import ( NamedElementOnion, @@ -121,6 +124,7 @@ async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes self.endpoint_uri, request_data, **self.get_request_kwargs() ) + @async_handle_request_caching async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug( f"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}" diff --git a/web3/providers/rpc/rpc.py b/web3/providers/rpc/rpc.py index 70f26af700..abd4243b0e 100644 --- a/web3/providers/rpc/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -40,6 +40,9 @@ check_if_retry_on_failure, ExceptionRetryConfiguration, ) +from ..._utils.caching import ( + handle_request_caching, +) class HTTPProvider(JSONBaseProvider): @@ -115,6 +118,7 @@ def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: self.endpoint_uri, request_data, **self.get_request_kwargs() ) + @handle_request_caching def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: self.logger.debug( f"Making request HTTP. URI: {self.endpoint_uri}, Method: {method}" diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index 949c6b0c68..09f4e2ce31 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -1,4 +1,5 @@ import asyncio +import copy import json import logging import os @@ -23,6 +24,7 @@ ) from web3._utils.caching import ( + async_handle_request_caching, generate_cache_key, ) from web3.exceptions import ( @@ -98,6 +100,21 @@ def __init__( def __str__(self) -> str: return f"Websocket connection: {self.endpoint_uri}" + @property + def _effective_request_number(self) -> int: + current_request_id = next(copy.deepcopy(self.request_counter)) - 1 + cache_key = generate_cache_key(current_request_id) + requests_info = self._request_processor._request_information_cache._data + + if cache_key not in requests_info: + return current_request_id + + while cache_key in requests_info: + current_request_id += 1 + cache_key = generate_cache_key(current_request_id) + + return current_request_id + async def is_connected(self, show_traceback: bool = False) -> bool: if not self._ws: return False @@ -147,6 +164,7 @@ async def disconnect(self) -> None: self._request_processor.clear_caches() + @async_handle_request_caching async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: request_data = self.encode_rpc_request(method, params) From 1c07cce23ebb888034eb12ed3156201ef508a36b Mon Sep 17 00:00:00 2001 From: fselmo Date: Tue, 21 Nov 2023 09:47:07 -0700 Subject: [PATCH 04/24] Remove middleware dependent on old logic - This middleware depends on a request always being made when middleware are processed. This isn't how the middleware is processed for websockets and the separation of middleware into request param processing and response processing further helps with being able to support batch requests. Remove all middleware that were dependent on a request always being sandwiched between request processing and response processing. --- .../middleware/test_fixture_middleware.py | 96 ----- ...est_latest_block_based_cache_middleware.py | 287 -------------- .../test_simple_cache_middleware.py | 352 ------------------ .../test_time_based_cache_middleware.py | 168 --------- web3/middleware/exception_handling.py | 49 --- web3/middleware/fixture.py | 190 ---------- .../simulate_unmined_transaction.py | 43 --- 7 files changed, 1185 deletions(-) delete mode 100644 tests/core/middleware/test_fixture_middleware.py delete mode 100644 tests/core/middleware/test_latest_block_based_cache_middleware.py delete mode 100644 tests/core/middleware/test_simple_cache_middleware.py delete mode 100644 tests/core/middleware/test_time_based_cache_middleware.py delete mode 100644 web3/middleware/exception_handling.py delete mode 100644 web3/middleware/fixture.py delete mode 100644 web3/middleware/simulate_unmined_transaction.py diff --git a/tests/core/middleware/test_fixture_middleware.py b/tests/core/middleware/test_fixture_middleware.py deleted file mode 100644 index 0c73baa469..0000000000 --- a/tests/core/middleware/test_fixture_middleware.py +++ /dev/null @@ -1,96 +0,0 @@ -import pytest - -from web3 import ( - Web3, -) -from web3.middleware import ( - construct_error_generator_middleware, - construct_fixture_middleware, - construct_result_generator_middleware, -) -from web3.providers.base import ( - BaseProvider, -) - - -class DummyProvider(BaseProvider): - def make_request(self, method, params): - raise NotImplementedError(f"Cannot make request for {method}:{params}") - - -@pytest.fixture -def w3(): - return Web3(provider=DummyProvider(), middlewares=[]) - - -@pytest.mark.parametrize( - "method,expected", - ( - ("test_endpoint", "value-a"), - ("not_implemented", NotImplementedError), - ), -) -def test_fixture_middleware(w3, method, expected): - w3.middleware_onion.add(construct_fixture_middleware({"test_endpoint": "value-a"})) - - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - w3.manager.request_blocking(method, []) - else: - actual = w3.manager.request_blocking(method, []) - assert actual == expected - - -@pytest.mark.parametrize( - "method,expected", - ( - ("test_endpoint", "value-a"), - ("not_implemented", NotImplementedError), - ), -) -def test_result_middleware(w3, method, expected): - def _callback(method, params): - return params[0] - - w3.middleware_onion.add( - construct_result_generator_middleware( - { - "test_endpoint": _callback, - } - ) - ) - - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - w3.manager.request_blocking(method, [expected]) - else: - actual = w3.manager.request_blocking(method, [expected]) - assert actual == expected - - -@pytest.mark.parametrize( - "method,expected", - ( - ("test_endpoint", "value-a"), - ("not_implemented", NotImplementedError), - ), -) -def test_error_middleware(w3, method, expected): - def _callback(method, params): - return params[0] - - w3.middleware_onion.add( - construct_error_generator_middleware( - { - "test_endpoint": _callback, - } - ) - ) - - if isinstance(expected, type) and issubclass(expected, Exception): - with pytest.raises(expected): - w3.manager.request_blocking(method, [expected]) - else: - with pytest.raises(ValueError) as err: - w3.manager.request_blocking(method, [expected]) - assert expected in str(err) diff --git a/tests/core/middleware/test_latest_block_based_cache_middleware.py b/tests/core/middleware/test_latest_block_based_cache_middleware.py deleted file mode 100644 index 32f88333af..0000000000 --- a/tests/core/middleware/test_latest_block_based_cache_middleware.py +++ /dev/null @@ -1,287 +0,0 @@ -import codecs -import itertools -import pytest -import time -import uuid - -from eth_utils import ( - is_hex, - is_integer, - to_tuple, -) - -from web3 import ( - Web3, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3._utils.formatters import ( - hex_to_integer, -) -from web3.middleware import ( - construct_error_generator_middleware, - construct_latest_block_based_cache_middleware, - construct_result_generator_middleware, -) -from web3.providers.base import ( - BaseProvider, -) - - -@pytest.fixture -def w3_base(): - return Web3(provider=BaseProvider(), middlewares=[]) - - -def _mk_block(n, timestamp): - return { - "hash": codecs.decode(str(n).zfill(32), "hex"), - "number": n, - "timestamp": timestamp, - } - - -@to_tuple -def generate_block_history(num_mined_blocks=5, block_time=1): - genesis = _mk_block(0, time.time()) - yield genesis - for block_number in range(1, num_mined_blocks + 1): - yield _mk_block( - block_number, - genesis["timestamp"] + 2 * block_number, - ) - - -@pytest.fixture -def construct_block_data_middleware(): - def _construct_block_data_middleware(num_blocks): - blocks = generate_block_history(num_blocks) - _block_info = {"blocks": blocks, "head_block_number": blocks[0]["number"]} - - def _evm_mine(method, params, block_info=_block_info): - num_blocks = params[0] - head_block_number = block_info["head_block_number"] - if head_block_number + num_blocks >= len(block_info["blocks"]): - raise ValueError("no more blocks to mine") - - block_info["head_block_number"] += num_blocks - - def _get_block_by_number(method, params, block_info=_block_info): - block_id = params[0] - blocks = block_info["blocks"] - head_block_number = block_info["head_block_number"] - - if block_id == "latest": - return blocks[head_block_number] - elif block_id == "pending": - if head_block_number + 1 >= len(blocks): - raise ValueError("no pending block") - return blocks[head_block_number + 1] - elif block_id == "earliest": - return blocks[0] - elif is_integer(block_id): - if block_id <= head_block_number: - return blocks[block_id] - else: - return None - elif is_hex(block_id): - block_id = hex_to_integer(block_id) - if block_id <= head_block_number: - return blocks[block_id] - else: - return None - else: - raise TypeError("Invalid type for block_id") - - def _get_block_by_hash(method, params, block_info=_block_info): - block_hash = params[0] - blocks = block_info["blocks"] - head_block_number = block_info["head_block_number"] - - blocks_by_hash = {block["hash"]: block for block in blocks} - try: - block = blocks_by_hash[block_hash] - if block["number"] <= head_block_number: - return block - else: - return None - except KeyError: - return None - - return construct_result_generator_middleware( - { - "eth_getBlockByNumber": _get_block_by_number, - "eth_getBlockByHash": _get_block_by_hash, - "evm_mine": _evm_mine, - } - ) - - return _construct_block_data_middleware - - -@pytest.fixture -def block_data_middleware(construct_block_data_middleware): - return construct_block_data_middleware(5) - - -@pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { - "fake_endpoint": lambda *_: str(uuid.uuid4()), - "not_whitelisted": lambda *_: str(uuid.uuid4()), - } - ) - - -@pytest.fixture -def latest_block_based_cache_middleware(): - return construct_latest_block_based_cache_middleware( - cache_class=dict, - average_block_time_sample_size=1, - default_average_block_time=0.1, - rpc_whitelist={"fake_endpoint"}, - ) - - -@pytest.fixture -def w3( - w3_base, - result_generator_middleware, - block_data_middleware, - latest_block_based_cache_middleware, -): - w3_base.middleware_onion.add(block_data_middleware) - w3_base.middleware_onion.add(result_generator_middleware) - w3_base.middleware_onion.add(latest_block_based_cache_middleware) - return w3_base - - -def test_latest_block_based_cache_middleware_pulls_from_cache( - w3_base, block_data_middleware, result_generator_middleware -): - w3 = w3_base - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add(result_generator_middleware) - - current_block_hash = w3.eth.get_block("latest")["hash"] - - def cache_class(): - return { - generate_cache_key((current_block_hash, "fake_endpoint", [1])): { - "result": "value-a" - }, - } - - w3.middleware_onion.add( - construct_latest_block_based_cache_middleware( - cache_class=cache_class, - rpc_whitelist={"fake_endpoint"}, - ) - ) - - assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" - - -def test_latest_block_based_cache_middleware_populates_cache(w3): - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - assert w3.manager.request_blocking("fake_endpoint", [1]) != result - - -def test_latest_block_based_cache_middleware_busts_cache(w3, mocker): - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - w3.testing.mine() - - # should still be cached for at least 1 second. This also verifies that - # the middleware caches the latest block based on the block time. - assert w3.manager.request_blocking("fake_endpoint", []) == result - - mocker.patch("time.time", return_value=time.time() + 5) - - assert w3.manager.request_blocking("fake_endpoint", []) != result - - -def test_latest_block_cache_middleware_does_not_cache_bad_responses( - w3_base, block_data_middleware, latest_block_based_cache_middleware -): - counter = itertools.count() - w3 = w3_base - - def result_cb(method, params): - next(counter) - return None - - w3 = w3_base - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add( - construct_result_generator_middleware( - { - "fake_endpoint": result_cb, - } - ) - ) - w3.middleware_onion.add(latest_block_based_cache_middleware) - - w3.manager.request_blocking("fake_endpoint", []) - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_latest_block_cache_middleware_does_not_cache_error_response( - w3_base, block_data_middleware, latest_block_based_cache_middleware -): - counter = itertools.count() - w3 = w3_base - - def error_cb(method, params): - next(counter) - return "the error message" - - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add( - construct_error_generator_middleware( - { - "fake_endpoint": error_cb, - } - ) - ) - w3.middleware_onion.add(latest_block_based_cache_middleware) - - with pytest.raises(ValueError): - w3.manager.request_blocking("fake_endpoint", []) - with pytest.raises(ValueError): - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_latest_block_cache_middleware_does_not_cache_get_latest_block( - w3_base, block_data_middleware, result_generator_middleware -): - w3 = w3_base - w3.middleware_onion.add(block_data_middleware) - w3.middleware_onion.add(result_generator_middleware) - - current_block_hash = w3.eth.get_block("latest")["hash"] - - def cache_class(): - return { - generate_cache_key( - (current_block_hash, "eth_getBlockByNumber", ["latest"]) - ): {"result": "value-a"}, - } - - w3.middleware_onion.add( - construct_latest_block_based_cache_middleware( - cache_class=cache_class, - rpc_whitelist={"eth_getBlockByNumber"}, - ) - ) - - assert w3.manager.request_blocking("eth_getBlockByNumber", ["latest"]) != "value-a" diff --git a/tests/core/middleware/test_simple_cache_middleware.py b/tests/core/middleware/test_simple_cache_middleware.py deleted file mode 100644 index 5a89e01496..0000000000 --- a/tests/core/middleware/test_simple_cache_middleware.py +++ /dev/null @@ -1,352 +0,0 @@ -import itertools -import pytest -import threading -import uuid - -from web3 import ( - AsyncWeb3, - Web3, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3.middleware import ( - async_simple_cache_middleware, - construct_error_generator_middleware, - construct_result_generator_middleware, - construct_simple_cache_middleware, - simple_cache_middleware, -) -from web3.middleware.async_cache import ( - async_construct_simple_cache_middleware, -) -from web3.middleware.fixture import ( - async_construct_error_generator_middleware, - async_construct_result_generator_middleware, -) -from web3.providers.base import ( - BaseProvider, -) -from web3.providers.eth_tester import ( - AsyncEthereumTesterProvider, -) -from web3.types import ( - RPCEndpoint, -) -from web3.utils.caching import ( - SimpleCache, -) - - -@pytest.fixture -def w3_base(): - return Web3(provider=BaseProvider(), middlewares=[]) - - -@pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()), - RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()), - } - ) - - -@pytest.fixture -def w3(w3_base, result_generator_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - return w3_base - - -def simple_cache_return_value_a(): - _cache = SimpleCache() - _cache.cache( - generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"), - {"result": "value-a"}, - ) - return _cache - - -def test_simple_cache_middleware_pulls_from_cache(w3): - w3.middleware_onion.add( - construct_simple_cache_middleware( - cache=simple_cache_return_value_a(), - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" - - -def test_simple_cache_middleware_populates_cache(w3): - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - assert w3.manager.request_blocking("fake_endpoint", [1]) != result - - -def test_simple_cache_middleware_does_not_cache_none_responses(w3_base): - counter = itertools.count() - w3 = w3_base - - def result_cb(_method, _params): - next(counter) - return None - - w3.middleware_onion.add( - construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): result_cb, - } - ) - ) - - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - w3.manager.request_blocking("fake_endpoint", []) - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_simple_cache_middleware_does_not_cache_error_responses(w3_base): - w3 = w3_base - w3.middleware_onion.add( - construct_error_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}", - } - ) - ) - - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - with pytest.raises(ValueError) as err_a: - w3.manager.request_blocking("fake_endpoint", []) - with pytest.raises(ValueError) as err_b: - w3.manager.request_blocking("fake_endpoint", []) - - assert str(err_a) != str(err_b) - - -def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): - w3.middleware_onion.add( - construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - ) - - result_a = w3.manager.request_blocking("not_whitelisted", []) - result_b = w3.manager.request_blocking("not_whitelisted", []) - - assert result_a != result_b - - -def test_simple_cache_middleware_does_not_share_state_between_providers(): - result_generator_a = construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 11111} - ) - result_generator_b = construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 22222} - ) - result_generator_c = construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 33333} - ) - - w3_a = Web3(provider=BaseProvider(), middlewares=[result_generator_a]) - w3_b = Web3(provider=BaseProvider(), middlewares=[result_generator_b]) - w3_c = Web3( # instantiate the Web3 instance with the cache middleware - provider=BaseProvider(), - middlewares=[ - result_generator_c, - simple_cache_middleware, - ], - ) - - w3_a.middleware_onion.add(simple_cache_middleware) - w3_b.middleware_onion.add(simple_cache_middleware) - - result_a = w3_a.manager.request_blocking("eth_chainId", []) - result_b = w3_b.manager.request_blocking("eth_chainId", []) - result_c = w3_c.manager.request_blocking("eth_chainId", []) - - assert result_a != result_b != result_c - assert result_a == 11111 - assert result_b == 22222 - assert result_c == 33333 - - -# -- async -- # - - -async def _async_simple_cache_middleware_for_testing(make_request, async_w3): - middleware = await async_construct_simple_cache_middleware( - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - return await middleware(make_request, async_w3) - - -@pytest.fixture -def async_w3(): - return AsyncWeb3( - provider=AsyncEthereumTesterProvider(), - middlewares=[ - (_async_simple_cache_middleware_for_testing, "simple_cache"), - ], - ) - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_pulls_from_cache(async_w3): - async def _properly_awaited_middleware(make_request, _async_w3): - middleware = await async_construct_simple_cache_middleware( - cache=simple_cache_return_value_a(), - rpc_whitelist={RPCEndpoint("fake_endpoint")}, - ) - return await middleware(make_request, _async_w3) - - async_w3.middleware_onion.inject( - _properly_awaited_middleware, - layer=0, - ) - - _result = await async_w3.manager.coro_request("fake_endpoint", [1]) - assert _result == "value-a" - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_populates_cache(async_w3): - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()), - } - ), - "result_generator", - layer=0, - ) - - result = await async_w3.manager.coro_request("fake_endpoint", []) - - _empty_params = await async_w3.manager.coro_request("fake_endpoint", []) - _non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1]) - - assert _empty_params == result - assert _non_empty_params != result - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_cache_none_responses(async_w3): - counter = itertools.count() - - def result_cb(_method, _params): - next(counter) - return None - - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - { - RPCEndpoint("fake_endpoint"): result_cb, - }, - ), - "result_generator", - layer=0, - ) - - await async_w3.manager.coro_request("fake_endpoint", []) - await async_w3.manager.coro_request("fake_endpoint", []) - - assert next(counter) == 2 - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_cache_error_responses(async_w3): - async_w3.middleware_onion.inject( - await async_construct_error_generator_middleware( - { - RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}", - } - ), - "error_generator", - layer=0, - ) - - with pytest.raises(ValueError) as err_a: - await async_w3.manager.coro_request("fake_endpoint", []) - with pytest.raises(ValueError) as err_b: - await async_w3.manager.coro_request("fake_endpoint", []) - - assert str(err_a) != str(err_b) - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_cache_non_whitelist_endpoints( - async_w3, -): - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - { - RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()), - } - ), - layer=0, - ) - - result_a = await async_w3.manager.coro_request("not_whitelisted", []) - result_b = await async_w3.manager.coro_request("not_whitelisted", []) - - assert result_a != result_b - - -@pytest.mark.asyncio -async def test_async_simple_cache_middleware_does_not_share_state_between_providers(): - result_generator_a = await async_construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 11111} - ) - result_generator_b = await async_construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 22222} - ) - result_generator_c = await async_construct_result_generator_middleware( - {RPCEndpoint("eth_chainId"): lambda *_: 33333} - ) - - w3_a = AsyncWeb3( - provider=AsyncEthereumTesterProvider(), middlewares=[result_generator_a] - ) - w3_b = AsyncWeb3( - provider=AsyncEthereumTesterProvider(), middlewares=[result_generator_b] - ) - w3_c = AsyncWeb3( # instantiate the Web3 instance with the cache middleware - provider=AsyncEthereumTesterProvider(), - middlewares=[ - result_generator_c, - async_simple_cache_middleware, - ], - ) - - w3_a.middleware_onion.add(async_simple_cache_middleware) - w3_b.middleware_onion.add(async_simple_cache_middleware) - - result_a = await w3_a.eth.chain_id - result_b = await w3_b.eth.chain_id - result_c = await w3_c.eth.chain_id - - assert result_a != result_b != result_c - assert result_a == 11111 - assert result_b == 22222 - assert result_c == 33333 diff --git a/tests/core/middleware/test_time_based_cache_middleware.py b/tests/core/middleware/test_time_based_cache_middleware.py deleted file mode 100644 index fa861b9948..0000000000 --- a/tests/core/middleware/test_time_based_cache_middleware.py +++ /dev/null @@ -1,168 +0,0 @@ -import itertools -import pytest -import time -import uuid - -from web3 import ( - Web3, -) -from web3._utils.caching import ( - generate_cache_key, -) -from web3.middleware import ( - construct_error_generator_middleware, - construct_result_generator_middleware, - construct_time_based_cache_middleware, -) -from web3.providers.base import ( - BaseProvider, -) - - -@pytest.fixture -def w3_base(): - return Web3(provider=BaseProvider(), middlewares=[]) - - -@pytest.fixture -def result_generator_middleware(): - return construct_result_generator_middleware( - { - "fake_endpoint": lambda *_: str(uuid.uuid4()), - "not_whitelisted": lambda *_: str(uuid.uuid4()), - } - ) - - -@pytest.fixture -def time_cache_middleware(): - return construct_time_based_cache_middleware( - cache_class=dict, - cache_expire_seconds=10, - rpc_whitelist={"fake_endpoint"}, - ) - - -@pytest.fixture -def w3(w3_base, result_generator_middleware, time_cache_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - w3_base.middleware_onion.add(time_cache_middleware) - return w3_base - - -def test_time_based_cache_middleware_pulls_from_cache(w3_base): - w3 = w3_base - - def cache_class(): - return { - generate_cache_key(("fake_endpoint", [1])): ( - time.time(), - {"result": "value-a"}, - ), - } - - w3.middleware_onion.add( - construct_time_based_cache_middleware( - cache_class=cache_class, - cache_expire_seconds=10, - rpc_whitelist={"fake_endpoint"}, - ) - ) - - assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" - - -def test_time_based_cache_middleware_populates_cache(w3): - result = w3.manager.request_blocking("fake_endpoint", []) - - assert w3.manager.request_blocking("fake_endpoint", []) == result - assert w3.manager.request_blocking("fake_endpoint", [1]) != result - - -def test_time_based_cache_middleware_expires_old_values( - w3_base, result_generator_middleware -): - w3 = w3_base - w3.middleware_onion.add(result_generator_middleware) - - def cache_class(): - return { - generate_cache_key(("fake_endpoint", [1])): ( - time.time() - 10, - {"result": "value-a"}, - ), - } - - w3.middleware_onion.add( - construct_time_based_cache_middleware( - cache_class=cache_class, - cache_expire_seconds=10, - rpc_whitelist={"fake_endpoint"}, - ) - ) - - result = w3.manager.request_blocking("fake_endpoint", [1]) - assert result != "value-a" - assert w3.manager.request_blocking("fake_endpoint", [1]) == result - - -@pytest.mark.parametrize( - "response", - ( - {}, - {"result": None}, - ), -) -def test_time_based_cache_middleware_does_not_cache_bad_responses( - w3_base, response, time_cache_middleware -): - w3 = w3_base - counter = itertools.count() - - def mk_result(method, params): - next(counter) - return None - - w3.middleware_onion.add( - construct_result_generator_middleware({"fake_endpoint": mk_result}) - ) - w3.middleware_onion.add(time_cache_middleware) - - w3.manager.request_blocking("fake_endpoint", []) - w3.manager.request_blocking("fake_endpoint", []) - - assert next(counter) == 2 - - -def test_time_based_cache_middleware_does_not_cache_error_response( - w3_base, time_cache_middleware -): - w3 = w3_base - counter = itertools.count() - - def mk_error(method, params): - return f"error-number-{next(counter)}" - - w3.middleware_onion.add( - construct_error_generator_middleware( - { - "fake_endpoint": mk_error, - } - ) - ) - w3.middleware_onion.add(time_cache_middleware) - - with pytest.raises(ValueError) as err: - w3.manager.request_blocking("fake_endpoint", []) - assert "error-number-0" in str(err) - - with pytest.raises(ValueError) as err: - w3.manager.request_blocking("fake_endpoint", []) - assert "error-number-1" in str(err) - - -def test_time_based_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3): - result_a = w3.manager.request_blocking("not_whitelisted", []) - result_b = w3.manager.request_blocking("not_whitelisted", []) - - assert result_a != result_b diff --git a/web3/middleware/exception_handling.py b/web3/middleware/exception_handling.py deleted file mode 100644 index d264c7eaf7..0000000000 --- a/web3/middleware/exception_handling.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - Tuple, - Type, -) - -from eth_utils.toolz import ( - excepts, -) - -from web3.types import ( - Middleware, - RPCEndpoint, - RPCResponse, -) - -if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 - - -def construct_exception_handler_middleware( - method_handlers: Optional[ - Dict[RPCEndpoint, Tuple[Type[BaseException], Callable[..., None]]] - ] = None -) -> Middleware: - if method_handlers is None: - method_handlers = {} - - def exception_handler_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in method_handlers: - exc_type, handler = method_handlers[method] - return excepts( - exc_type, - make_request, - handler, - )(method, params) - else: - return make_request(method, params) - - return middleware - - return exception_handler_middleware diff --git a/web3/middleware/fixture.py b/web3/middleware/fixture.py deleted file mode 100644 index 6c87ee8f9a..0000000000 --- a/web3/middleware/fixture.py +++ /dev/null @@ -1,190 +0,0 @@ -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Dict, - Optional, - cast, -) - -from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareCoroutine, - Middleware, - RPCEndpoint, - RPCResponse, -) - -if TYPE_CHECKING: - from web3.main import ( # noqa: F401 - AsyncWeb3, - Web3, - ) - from web3.providers import ( # noqa: F401 - PersistentConnectionProvider, - ) - - -def construct_fixture_middleware(fixtures: Dict[RPCEndpoint, Any]) -> Middleware: - """ - Constructs a middleware which returns a static response for any method - which is found in the provided fixtures. - """ - - def fixture_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in fixtures: - result = fixtures[method] - return {"result": result} - else: - return make_request(method, params) - - return middleware - - return fixture_middleware - - -def construct_result_generator_middleware( - result_generators: Dict[RPCEndpoint, Any] -) -> Middleware: - """ - Constructs a middleware which intercepts requests for any method found in - the provided mapping of endpoints to generator functions, returning - whatever response the generator function returns. Callbacks must be - functions with the signature `fn(method, params)`. - """ - - def result_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in result_generators: - result = result_generators[method](method, params) - return {"result": result} - else: - return make_request(method, params) - - return middleware - - return result_generator_middleware - - -def construct_error_generator_middleware( - error_generators: Dict[RPCEndpoint, Any] -) -> Middleware: - """ - Constructs a middleware which intercepts requests for any method found in - the provided mapping of endpoints to generator functions, returning - whatever error message the generator function returns. Callbacks must be - functions with the signature `fn(method, params)`. - """ - - def error_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], _: "Web3" - ) -> Callable[[RPCEndpoint, Any], RPCResponse]: - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in error_generators: - error = error_generators[method](method, params) - if isinstance(error, dict) and error.get("error", False): - return { - "error": { - "code": error.get("code", -32000), - "message": error["error"].get("message", ""), - "data": error.get("data", ""), - } - } - else: - return {"error": error} - else: - return make_request(method, params) - - return middleware - - return error_generator_middleware - - -# --- async --- # - - -async def async_construct_result_generator_middleware( - result_generators: Dict[RPCEndpoint, Any] -) -> AsyncMiddleware: - """ - Constructs a middleware which returns a static response for any method - which is found in the provided fixtures. - """ - - async def result_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if method in result_generators: - result = result_generators[method](method, params) - - if async_w3.provider.has_persistent_connection: - provider = cast("PersistentConnectionProvider", async_w3.provider) - response = await make_request(method, params) - provider._request_processor.append_middleware_response_processor( - # processed asynchronously later but need to pass the actual - # response to the next middleware - response, - lambda _: {"result": result}, - ) - return response - else: - return {"result": result} - else: - return await make_request(method, params) - - return middleware - - return result_generator_middleware - - -async def async_construct_error_generator_middleware( - error_generators: Dict[RPCEndpoint, Any] -) -> AsyncMiddleware: - """ - Constructs a middleware which intercepts requests for any method found in - the provided mapping of endpoints to generator functions, returning - whatever error message the generator function returns. Callbacks must be - functions with the signature `fn(method, params)`. - """ - - async def error_generator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "AsyncWeb3" - ) -> AsyncMiddlewareCoroutine: - async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if method in error_generators: - error = error_generators[method](method, params) - if isinstance(error, dict) and error.get("error", False): - error_response = { - "error": { - "code": error.get("code", -32000), - "message": error["error"].get("message", ""), - "data": error.get("data", ""), - } - } - else: - error_response = {"error": error} - - if async_w3.provider.has_persistent_connection: - provider = cast("PersistentConnectionProvider", async_w3.provider) - response = await make_request(method, params) - provider._request_processor.append_middleware_response_processor( - # processed asynchronously later but need to pass the actual - # response to the next middleware - response, - lambda _: error_response, - ) - return response - else: - return cast(RPCResponse, error_response) - else: - return await make_request(method, params) - - return middleware - - return error_generator_middleware diff --git a/web3/middleware/simulate_unmined_transaction.py b/web3/middleware/simulate_unmined_transaction.py deleted file mode 100644 index 18e2f8bf2c..0000000000 --- a/web3/middleware/simulate_unmined_transaction.py +++ /dev/null @@ -1,43 +0,0 @@ -import collections -import itertools -from typing import ( - Any, - Callable, -) - -from eth_typing import ( - Hash32, -) - -from web3 import ( - Web3, -) -from web3.types import ( - RPCEndpoint, - RPCResponse, - TxReceipt, -) - -counter = itertools.count() - -INVOCATIONS_BEFORE_RESULT = 5 - - -def unmined_receipt_simulator_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: Web3 -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - receipt_counters: DefaultDict[Hash32, TxReceipt] = collections.defaultdict( # type: ignore # noqa: F821, E501 - itertools.count - ) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method == "eth_getTransactionReceipt": - txn_hash = params[0] - if next(receipt_counters[txn_hash]) < INVOCATIONS_BEFORE_RESULT: - return {"result": None} - else: - return make_request(method, params) - else: - return make_request(method, params) - - return middleware From af912fadf60dbb411c646d282e0fdfcb94c2d917 Mon Sep 17 00:00:00 2001 From: fselmo Date: Tue, 21 Nov 2023 22:29:48 -0700 Subject: [PATCH 05/24] Add request mocker fixture for mocking requests - Add a ``RequestMocker`` class with ``request_mocker`` pytest fixture to fascilitate request mocking now that the middleware doesn't sequester the ``make_request()`` call. Since we can't swap the ``make_request()`` on the fly, mock results and errors using this test utility class. - Linting - Some cleanup along the way --- ens/utils.py | 1 - tests/conftest.py | 15 ++ tests/core/eth-module/test_block_api.py | 28 +-- tests/core/eth-module/test_poa.py | 80 ++++---- tests/core/utilities/test_fee_utils.py | 36 +--- .../go_ethereum/test_goethereum_http.py | 2 +- web3/_utils/caching.py | 1 - web3/_utils/module_testing/eth_module.py | 173 ++++++------------ .../persistent_connection_provider.py | 6 +- web3/_utils/module_testing/utils.py | 113 ++++++++++++ web3/manager.py | 2 +- web3/middleware/abi.py | 1 - web3/middleware/attrdict.py | 4 +- web3/middleware/formatting.py | 25 ++- web3/middleware/names.py | 6 +- web3/middleware/proof_of_authority.py | 4 +- web3/middleware/validation.py | 1 - web3/providers/async_base.py | 9 +- web3/providers/base.py | 5 +- web3/providers/rpc/async_rpc.py | 4 +- web3/providers/rpc/rpc.py | 8 +- web3/providers/rpc/utils.py | 16 +- 22 files changed, 288 insertions(+), 252 deletions(-) create mode 100644 web3/_utils/module_testing/utils.py diff --git a/ens/utils.py b/ens/utils.py index d2da542f0d..7b7aa33dcf 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -5,7 +5,6 @@ from typing import ( TYPE_CHECKING, Any, - Callable, Collection, Dict, List, diff --git a/tests/conftest.py b/tests/conftest.py index 9e194dbcf8..6a4b502867 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,7 @@ import pytest +from typing import ( + Type, +) from eth_utils import ( event_signature_to_log_topic, @@ -7,10 +10,14 @@ from eth_utils.toolz import ( identity, ) +import pytest_asyncio from web3._utils.contract_sources.contract_data.emitter_contract import ( EMITTER_CONTRACT_DATA, ) +from web3._utils.module_testing.utils import ( + RequestMocker, +) from .utils import ( get_open_port, @@ -107,3 +114,11 @@ class LogTopics: @pytest.fixture(scope="session") def emitter_contract_log_topics(): return LogTopics + + +# -- mock requests -- # + + +@pytest_asyncio.fixture(scope="function") +def request_mocker() -> Type[RequestMocker]: + return RequestMocker diff --git a/tests/core/eth-module/test_block_api.py b/tests/core/eth-module/test_block_api.py index b9abae1329..3b64161fac 100644 --- a/tests/core/eth-module/test_block_api.py +++ b/tests/core/eth-module/test_block_api.py @@ -1,4 +1,7 @@ import pytest +from unittest.mock import ( + Mock, +) from eth_utils import ( to_checksum_address, @@ -7,13 +10,6 @@ HexBytes, ) -from web3._utils.rpc_abi import ( - RPC, -) -from web3.middleware import ( - construct_result_generator_middleware, -) - @pytest.fixture(autouse=True) def wait_for_first_block(w3, wait_for_block): @@ -26,7 +22,7 @@ def test_uses_default_block(w3, extra_accounts, wait_for_transaction): assert w3.eth.default_block == w3.eth.block_number -def test_get_block_formatters_with_null_values(w3): +def test_get_block_formatters_with_null_values(w3, mocker): null_values_block = { "baseFeePerGas": None, "extraData": None, @@ -51,13 +47,7 @@ def test_get_block_formatters_with_null_values(w3): "withdrawalsRoot": None, "withdrawals": [], } - result_middleware = construct_result_generator_middleware( - { - RPC.eth_getBlockByNumber: lambda *_: null_values_block, - } - ) - - w3.middleware_onion.inject(result_middleware, "result_middleware", layer=0) + mocker.patch("web3.eth.eth.Eth.get_block", return_value=null_values_block) received_block = w3.eth.get_block("pending") assert received_block == null_values_block @@ -116,13 +106,9 @@ def test_get_block_formatters_with_pre_formatted_values(w3): }, ], } - result_middleware = construct_result_generator_middleware( - { - RPC.eth_getBlockByNumber: lambda *_: unformatted_values_block, - } - ) - w3.middleware_onion.inject(result_middleware, "result_middleware", layer=0) + w3.manager._make_request = Mock() + w3.manager._make_request.return_value = {"result": unformatted_values_block} received_block = w3.eth.get_block("pending") diff --git a/tests/core/eth-module/test_poa.py b/tests/core/eth-module/test_poa.py index a9481b3245..6e0c901458 100644 --- a/tests/core/eth-module/test_poa.py +++ b/tests/core/eth-module/test_poa.py @@ -5,54 +5,42 @@ ExtraDataLengthError, ) from web3.middleware import ( - construct_fixture_middleware, - extradata_to_poa, + extradata_to_poa_middleware, ) # In the spec, a block with extra data longer than 32 bytes is invalid -def test_long_extra_data(w3): - return_block_with_long_extra_data = construct_fixture_middleware( - { - "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, - } - ) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - with pytest.raises(ExtraDataLengthError): - w3.eth.get_block("latest") - - -def test_full_extra_data(w3): - return_block_with_long_extra_data = construct_fixture_middleware( - { - "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 32}, - } - ) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - block = w3.eth.get_block("latest") - assert block.extraData == b"\xff" * 32 - - -def test_geth_proof_of_authority(w3): - return_block_with_long_extra_data = construct_fixture_middleware( - { +def test_long_extra_data(w3, request_mocker): + with request_mocker( + w3, mock_results={"eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}} + ): + with pytest.raises(ExtraDataLengthError): + w3.eth.get_block("latest") + + +def test_full_extra_data(w3, request_mocker): + with request_mocker( + w3, mock_results={"eth_getBlockByNumber": {"extraData": "0x" + "ff" * 32}} + ): + block = w3.eth.get_block("latest") + assert block.extraData == b"\xff" * 32 + + +def test_geth_proof_of_authority(w3, request_mocker): + w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) + + with request_mocker( + w3, + mock_results={ "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, - } - ) - w3.middleware_onion.inject(extradata_to_poa, layer=0) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - block = w3.eth.get_block("latest") - assert "extraData" not in block - assert block.proofOfAuthorityData == b"\xff" * 33 - - -def test_returns_none_response(w3): - return_none_response = construct_fixture_middleware( - { - "eth_getBlockByNumber": None, - } - ) - w3.middleware_onion.inject(extradata_to_poa, layer=0) - w3.middleware_onion.inject(return_none_response, layer=0) - with pytest.raises(BlockNotFound): - w3.eth.get_block(100000000000) + }, + ): + block = w3.eth.get_block("latest") + assert "extraData" not in block + assert block.proofOfAuthorityData == b"\xff" * 33 + + +def test_returns_none_response(w3, request_mocker): + with request_mocker(w3, mock_results={"eth_getBlockByNumber": None}): + with pytest.raises(BlockNotFound): + w3.eth.get_block("latest") diff --git a/tests/core/utilities/test_fee_utils.py b/tests/core/utilities/test_fee_utils.py index 84a1fb743d..a9bbdb1639 100644 --- a/tests/core/utilities/test_fee_utils.py +++ b/tests/core/utilities/test_fee_utils.py @@ -4,14 +4,6 @@ is_integer, ) -from web3.middleware import ( - construct_error_generator_middleware, - construct_result_generator_middleware, -) -from web3.types import ( - RPCEndpoint, -) - @pytest.mark.parametrize( "fee_history_rewards,expected_max_prio_calc", @@ -37,28 +29,20 @@ ) # Test fee_utils indirectly by mocking eth_feeHistory results # and checking against expected output -def test_fee_utils_indirectly(w3, fee_history_rewards, expected_max_prio_calc) -> None: - fail_max_prio_middleware = construct_error_generator_middleware( - {RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: ""} - ) - fee_history_result_middleware = construct_result_generator_middleware( - {RPCEndpoint("eth_feeHistory"): lambda *_: {"reward": fee_history_rewards}} - ) - - w3.middleware_onion.add(fail_max_prio_middleware, "fail_max_prio") - w3.middleware_onion.inject( - fee_history_result_middleware, "fee_history_result", layer=0 - ) - +def test_fee_utils_indirectly( + w3, fee_history_rewards, expected_max_prio_calc, request_mocker +) -> None: with pytest.warns( UserWarning, match="There was an issue with the method eth_maxPriorityFeePerGas. " "Calculating using eth_feeHistory.", ): - max_priority_fee = w3.eth.max_priority_fee + with request_mocker( + w3, + mock_errors={"eth_maxPriorityFeePerGas": {}}, + mock_results={"eth_feeHistory": {"reward": fee_history_rewards}}, + ): + max_priority_fee = w3.eth.max_priority_fee + assert is_integer(max_priority_fee) assert max_priority_fee == expected_max_prio_calc - - # clean up - w3.middleware_onion.remove("fail_max_prio") - w3.middleware_onion.remove("fee_history_result") diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index 5013e7cf80..9963854437 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -18,7 +18,7 @@ from web3._utils.module_testing.go_ethereum_personal_module import ( GoEthereumAsyncPersonalModuleTest, ) -from web3.providers.async_rpc import ( +from web3.providers.rpc import ( AsyncHTTPProvider, ) diff --git a/web3/_utils/caching.py b/web3/_utils/caching.py index 7af5235f1c..e047acef8f 100644 --- a/web3/_utils/caching.py +++ b/web3/_utils/caching.py @@ -22,7 +22,6 @@ to_bytes, ) - if TYPE_CHECKING: from web3.providers import ( # noqa: F401 BaseProvider, diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index 5f479e060f..34659e954e 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -10,6 +10,7 @@ Any, Callable, List, + Type, Union, cast, ) @@ -49,6 +50,9 @@ from web3._utils.ens import ( ens_addresses, ) +from web3._utils.fee_utils import ( + PRIORITY_FEE_MIN, +) from web3._utils.method_formatters import ( to_hex_if_integer, ) @@ -58,6 +62,9 @@ mine_pending_block, mock_offchain_lookup_request_response, ) +from web3._utils.module_testing.utils import ( + RequestMocker, +) from web3._utils.type_conversion import ( to_hex_if_bytes, ) @@ -78,12 +85,7 @@ Web3ValidationError, ) from web3.middleware import ( - async_geth_poa_middleware, -) -from web3.middleware.fixture import ( - async_construct_error_generator_middleware, - async_construct_result_generator_middleware, - construct_error_generator_middleware, + extradata_to_poa_middleware, ) from web3.types import ( ENS, @@ -648,27 +650,23 @@ async def test_validation_middleware_chain_id_mismatch( await async_w3.eth.send_transaction(txn_params) @pytest.mark.asyncio - async def test_geth_poa_middleware(self, async_w3: "AsyncWeb3") -> None: - return_block_with_long_extra_data = ( - await async_construct_result_generator_middleware( - { - RPCEndpoint("eth_getBlockByNumber"): lambda *_: { - "extraData": "0x" + "ff" * 33 - }, - } - ) - ) - async_w3.middleware_onion.inject(async_geth_poa_middleware, "poa", layer=0) - async_w3.middleware_onion.inject( - return_block_with_long_extra_data, "extradata", layer=0 - ) - block = await async_w3.eth.get_block("latest") + async def test_geth_poa_middleware( + self, async_w3: "AsyncWeb3", request_mocker + ) -> None: + async_w3.middleware_onion.inject(extradata_to_poa_middleware, "poa", layer=0) + extra_data = f"0x{'ff' * 33}" + + async with request_mocker( + async_w3, + mock_results={"eth_getBlockByNumber": {"extraData": extra_data}}, + ): + block = await async_w3.eth.get_block("latest") + assert "extraData" not in block - assert block["proofOfAuthorityData"] == b"\xff" * 33 + assert block["proofOfAuthorityData"] == to_bytes(hexstr=extra_data) # clean up async_w3.middleware_onion.remove("poa") - async_w3.middleware_onion.remove("extradata") @pytest.mark.asyncio async def test_eth_send_raw_transaction(self, async_w3: "AsyncWeb3") -> None: @@ -852,60 +850,25 @@ async def test_eth_max_priority_fee(self, async_w3: "AsyncWeb3") -> None: max_priority_fee = await async_w3.eth.max_priority_fee assert is_integer(max_priority_fee) - @pytest.mark.asyncio - async def test_eth_max_priority_fee_with_fee_history_calculation_error_dict( - self, async_w3: "AsyncWeb3" - ) -> None: - fail_max_prio_middleware = await async_construct_error_generator_middleware( - { - RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: { - "error": { - "code": -32601, - "message": ( - "The method eth_maxPriorityFeePerGas does " - "not exist/is not available" - ), - } - } - } - ) - async_w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas." - " Calculating using eth_feeHistory." - ), - ): - await async_w3.eth.max_priority_fee - - async_w3.middleware_onion.remove("fail_max_prio_middleware") # clean up - @pytest.mark.asyncio async def test_eth_max_priority_fee_with_fee_history_calculation( - self, async_w3: "AsyncWeb3" + self, async_w3: "AsyncWeb3", request_mocker: Type[RequestMocker] ) -> None: - fail_max_prio_middleware = await async_construct_error_generator_middleware( - {RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: ""} - ) - async_w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas. " - "Calculating using eth_feeHistory." - ), + async with request_mocker( + async_w3, + mock_errors={RPCEndpoint("eth_maxPriorityFeePerGas"): {}}, + mock_results={RPCEndpoint("eth_feeHistory"): {"reward": [[0]]}}, ): - max_priority_fee = await async_w3.eth.max_priority_fee - assert is_integer(max_priority_fee) - - async_w3.middleware_onion.remove("fail_max_prio_middleware") # clean up + with pytest.warns( + UserWarning, + match=( + "There was an issue with the method eth_maxPriorityFeePerGas. " + "Calculating using eth_feeHistory." + ), + ): + priority_fee = await async_w3.eth.max_priority_fee + assert is_integer(priority_fee) + assert priority_fee == PRIORITY_FEE_MIN @pytest.mark.asyncio async def test_eth_getBlockByHash( @@ -2480,58 +2443,24 @@ def test_eth_max_priority_fee(self, w3: "Web3") -> None: max_priority_fee = w3.eth.max_priority_fee assert is_integer(max_priority_fee) - def test_eth_max_priority_fee_with_fee_history_calculation_error_dict( - self, w3: "Web3" - ) -> None: - fail_max_prio_middleware = construct_error_generator_middleware( - { - RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: { - "error": { - "code": -32601, - "message": ( - "The method eth_maxPriorityFeePerGas does " - "not exist/is not available" - ), - } - } - } - ) - w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas." - " Calculating using eth_feeHistory." - ), - ): - w3.eth.max_priority_fee - - w3.middleware_onion.remove("fail_max_prio_middleware") # clean up - def test_eth_max_priority_fee_with_fee_history_calculation( - self, w3: "Web3" + self, w3: "Web3", request_mocker: Type[RequestMocker] ) -> None: - fail_max_prio_middleware = construct_error_generator_middleware( - {RPCEndpoint("eth_maxPriorityFeePerGas"): lambda *_: ""} - ) - w3.middleware_onion.add( - fail_max_prio_middleware, name="fail_max_prio_middleware" - ) - - with pytest.warns( - UserWarning, - match=( - "There was an issue with the method eth_maxPriorityFeePerGas." - " Calculating using eth_feeHistory." - ), + with request_mocker( + w3, + mock_errors={RPCEndpoint("eth_maxPriorityFeePerGas"): {}}, + mock_results={RPCEndpoint("eth_feeHistory"): {"reward": [[0]]}}, ): - max_priority_fee = w3.eth.max_priority_fee - assert is_integer(max_priority_fee) - - w3.middleware_onion.remove("fail_max_prio_middleware") # clean up + with pytest.warns( + UserWarning, + match=( + "There was an issue with the method eth_maxPriorityFeePerGas. " + "Calculating using eth_feeHistory." + ), + ): + max_priority_fee = w3.eth.max_priority_fee + assert is_integer(max_priority_fee) + assert max_priority_fee == PRIORITY_FEE_MIN def test_eth_accounts(self, w3: "Web3") -> None: accounts = w3.eth.accounts diff --git a/web3/_utils/module_testing/persistent_connection_provider.py b/web3/_utils/module_testing/persistent_connection_provider.py index b07416c029..7cf314122c 100644 --- a/web3/_utils/module_testing/persistent_connection_provider.py +++ b/web3/_utils/module_testing/persistent_connection_provider.py @@ -21,7 +21,7 @@ AttributeDict, ) from web3.middleware import ( - async_geth_poa_middleware, + extradata_to_poa_middleware, ) from web3.types import ( FormattedEthSubscriptionResponse, @@ -34,7 +34,7 @@ def _mocked_recv(sub_id: str, ws_subscription_response: Dict[str, Any]) -> bytes: - # Must be same subscription id so we can know how to parse the message. + # Must be same subscription id, so we can know how to parse the message. # We don't have this information when mocking the response. ws_subscription_response["params"]["subscription"] = sub_id return to_bytes(text=json.dumps(ws_subscription_response)) @@ -321,7 +321,7 @@ async def test_async_geth_poa_middleware_on_eth_subscription( async_w3: "_PersistentConnectionWeb3", ) -> None: async_w3.middleware_onion.inject( - async_geth_poa_middleware, "poa_middleware", layer=0 + extradata_to_poa_middleware, "poa_middleware", layer=0 ) sub_id = await async_w3.eth.subscribe("newHeads") diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py new file mode 100644 index 0000000000..9dbecb1734 --- /dev/null +++ b/web3/_utils/module_testing/utils.py @@ -0,0 +1,113 @@ +import copy +from typing import ( + Any, + Dict, + TypeVar, +) + +from toolz import ( + merge, +) + +from web3 import ( + AsyncWeb3, + Web3, +) +from web3.exceptions import ( + Web3ValidationError, +) +from web3.types import ( + RPCEndpoint, +) + +WEB3 = TypeVar("WEB3", Web3, AsyncWeb3) + + +class RequestMocker: + """ + Context manager to mock requests made by a web3 instance. This is meant to be used + via a ``request_mocker`` fixture defined within the appropriate context. + + Example: + + def test_my_w3(w3, request_mocker): + assert w3.eth.block_number == 0 + + with request_mocker(w3, mock_results={"eth_blockNumber": "0x1"}): + assert w3.eth.block_number == 1 + + assert w3.eth.block_number == 0 + + ``mock_results`` is a dict mapping method names to the desired "result" object of + the RPC response. ``mock_errors`` is a dict mapping method names to the desired + "error" object of the RPC response. If a method name is not in either dict, + the request is made as usual. + """ + + def __init__( + self, + w3: WEB3, + mock_results: Dict[RPCEndpoint, Dict[str, Any]] = None, + mock_errors: Dict[RPCEndpoint, Dict[str, Any]] = None, + ): + self.w3 = w3 + self.mock_results = mock_results or {} + self.mock_errors = mock_errors or {} + self._make_request = w3.provider.make_request + + def __enter__(self): + self.w3.provider.make_request = self._mock_request_handler + # reset request func cache to re-build request_func with mocked make_request + self.w3.provider._request_func_cache = (None, None) + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.w3.provider.make_request = self._make_request + # reset request func cache to re-build request_func with original make_request + self.w3.provider._request_func_cache = (None, None) + + def _mock_request_handler(self, method, params): + if method in self.mock_results: + return {"result": self.mock_results[method]} + elif method in self.mock_errors: + error = self.mock_errors[method] + if not isinstance(error, dict): + raise Web3ValidationError("error must be a dict") + code = error.get("code", -32000) + message = error.get("message", "Mocked error") + return {"error": merge({"code": code, "message": message}, error)} + else: + return self._make_request(method, params) + + # -- async -- # + async def __aenter__(self): + self.w3.provider.make_request = self._async_mock_request_handler + # reset request func cache to re-build request_func with mocked make_request + self.w3.provider._request_func_cache = (None, None) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self.w3.provider.make_request = self._make_request + # reset request func cache to re-build request_func with original make_request + self.w3.provider._request_func_cache = (None, None) + + async def _async_mock_request_handler(self, method, params): + if method not in self.mock_errors and method not in self.mock_results: + return await self._make_request(method, params) + + response_dict = { + "jsonrpc": "2.0", + "id": next(copy.deepcopy(self.w3.provider.request_counter)), + } + if method in self.mock_results: + return merge(response_dict, {"result": self.mock_results[method]}) + elif method in self.mock_errors: + error = self.mock_errors[method] + if not isinstance(error, dict): + raise Web3ValidationError("error must be a dict") + code = error.get("code", -32000) + message = error.get("message", "Mocked error") + return merge( + response_dict, + {"error": merge({"code": code, "message": message}, error)}, + ) diff --git a/web3/manager.py b/web3/manager.py index 770c040d40..a7c666a179 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -39,8 +39,8 @@ from web3.middleware import ( attrdict_middleware, buffered_gas_estimate_middleware, - gas_price_strategy_middleware, ens_name_to_address_middleware, + gas_price_strategy_middleware, validation_middleware, ) from web3.module import ( diff --git a/web3/middleware/abi.py b/web3/middleware/abi.py index 65290ce87b..d285893006 100644 --- a/web3/middleware/abi.py +++ b/web3/middleware/abi.py @@ -5,5 +5,4 @@ FormattingMiddleware, ) - abi_middleware = FormattingMiddleware(request_formatters=ABI_REQUEST_FORMATTERS) diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index 3c0eddbd19..1a1d4e195a 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -1,4 +1,6 @@ -from abc import ABC +from abc import ( + ABC, +) from typing import ( TYPE_CHECKING, Any, diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 8d96a64b84..1d2f50ef9e 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -183,9 +183,22 @@ async def async_response_processor( self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] - return _apply_response_formatters( - method, - self.result_formatters, - self.error_formatters, - response, - ) + if self._w3.provider.has_persistent_connection: + # asynchronous response processing + provider = cast("PersistentConnectionProvider", self._w3.provider) + provider._request_processor.append_middleware_response_processor( + response, + _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + ), + ) + return response + else: + return _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + response, + ) diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 8376933a36..bc53614c91 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -21,9 +21,6 @@ from web3.types import ( RPCEndpoint, ) -from .base import ( - Web3Middleware, -) from .._utils.abi import ( abi_data_tree, @@ -33,6 +30,9 @@ from .._utils.formatters import ( recursive_map, ) +from .base import ( + Web3Middleware, +) from .formatting import ( FormattingMiddleware, ) diff --git a/web3/middleware/proof_of_authority.py b/web3/middleware/proof_of_authority.py index a655a093fa..d162b295d4 100644 --- a/web3/middleware/proof_of_authority.py +++ b/web3/middleware/proof_of_authority.py @@ -22,7 +22,9 @@ from web3._utils.rpc_abi import ( RPC, ) -from web3.middleware.formatting import FormattingMiddleware +from web3.middleware.formatting import ( + FormattingMiddleware, +) if TYPE_CHECKING: from web3 import ( # noqa: F401 diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 99b3660a59..cbd08067db 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -36,7 +36,6 @@ FormattingMiddleware, ) from web3.types import ( - AsyncMiddlewareCoroutine, Formatters, FormattersDict, RPCEndpoint, diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 2ba9d796ca..f7f641cec7 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -25,15 +25,18 @@ from web3.middleware import ( async_combine_middlewares, ) -from web3.middleware.base import Web3Middleware +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( AsyncMiddlewareOnion, MiddlewareOnion, RPCEndpoint, RPCResponse, ) -from web3.utils import SimpleCache - +from web3.utils import ( + SimpleCache, +) if TYPE_CHECKING: from web3 import ( # noqa: F401 diff --git a/web3/providers/base.py b/web3/providers/base.py index ebd2b19dbe..679cda2e68 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -30,8 +30,9 @@ RPCEndpoint, RPCResponse, ) -from web3.utils import SimpleCache - +from web3.utils import ( + SimpleCache, +) if TYPE_CHECKING: from web3 import Web3 # noqa: F401 diff --git a/web3/providers/rpc/async_rpc.py b/web3/providers/rpc/async_rpc.py index 51f974ab04..e2279c72b1 100644 --- a/web3/providers/rpc/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -32,10 +32,10 @@ RPCEndpoint, RPCResponse, ) + from ..._utils.caching import ( async_handle_request_caching, ) - from ...datastructures import ( NamedElementOnion, ) @@ -43,8 +43,8 @@ AsyncJSONBaseProvider, ) from .utils import ( - check_if_retry_on_failure, ExceptionRetryConfiguration, + check_if_retry_on_failure, ) diff --git a/web3/providers/rpc/rpc.py b/web3/providers/rpc/rpc.py index abd4243b0e..9d7aa05997 100644 --- a/web3/providers/rpc/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -33,15 +33,15 @@ RPCResponse, ) +from ..._utils.caching import ( + handle_request_caching, +) from ..base import ( JSONBaseProvider, ) from .utils import ( - check_if_retry_on_failure, ExceptionRetryConfiguration, -) -from ..._utils.caching import ( - handle_request_caching, + check_if_retry_on_failure, ) diff --git a/web3/providers/rpc/utils.py b/web3/providers/rpc/utils.py index c829448dc2..9d87ad5bc4 100644 --- a/web3/providers/rpc/utils.py +++ b/web3/providers/rpc/utils.py @@ -1,11 +1,16 @@ -import asyncio -from typing import Sequence, Type +from typing import ( + Sequence, + Type, +) +from pydantic import ( + BaseModel, +) import requests -from pydantic import BaseModel - -from web3.types import RPCEndpoint +from web3.types import ( + RPCEndpoint, +) REQUEST_RETRY_ALLOWLIST = [ "admin", @@ -84,7 +89,6 @@ class ExceptionRetryConfiguration(BaseModel): errors: Sequence[Type[BaseException]] = ( ConnectionError, requests.HTTPError, - asyncio.Timeout, requests.Timeout, ) retries: int = 5 From 9c2f959e25bdf6d4debd03a8218f6bfb5367f3cd Mon Sep 17 00:00:00 2001 From: fselmo Date: Wed, 29 Nov 2023 12:34:24 -0700 Subject: [PATCH 06/24] Some refactoring to make cached_requests work with WSV2 --- web3/_utils/caching.py | 16 ++------- web3/module.py | 6 +++- web3/providers/async_base.py | 12 ++++--- web3/providers/base.py | 13 +++++--- web3/providers/websocket/request_processor.py | 33 +++++++++++++++---- web3/providers/websocket/websocket_v2.py | 18 ++-------- 6 files changed, 54 insertions(+), 44 deletions(-) diff --git a/web3/_utils/caching.py b/web3/_utils/caching.py index e047acef8f..b95eae603a 100644 --- a/web3/_utils/caching.py +++ b/web3/_utils/caching.py @@ -67,7 +67,7 @@ def __init__( def is_cacheable_request(provider: "BaseProvider", method: "RPCEndpoint") -> bool: - if provider._cache_allowed_requests and method in provider._cacheable_requests: + if provider.cache_allowed_requests and method in provider.cacheable_requests: return True @@ -101,21 +101,11 @@ async def wrapper(*args, **kwargs): provider, method, params = args if is_cacheable_request(provider, method): - if provider.has_persistent_connection: - next_request_id = provider._effective_request_number + 1 - else: - next_request_id = copy.deepcopy(next(provider.request_counter)) - request_cache = provider._request_cache - cache_key = generate_cache_key((method, params, kwargs)) + cache_key = generate_cache_key((method, params)) cache_result = request_cache.get_cache_entry(cache_key) if cache_result is not None: - # Increment request counter. This both makes the request unique, - # independent of where the response came from, and ensures that - # calculated request and response id matching can happen for - # persistent connection providers. - next(provider.request_counter) - return merge(cache_result, {"id": next_request_id}) + return cache_result else: response = await func(*args, **kwargs) request_cache.cache(cache_key, response) diff --git a/web3/module.py b/web3/module.py index 41ad2ca156..e04cf5ea31 100644 --- a/web3/module.py +++ b/web3/module.py @@ -103,7 +103,11 @@ async def caller(*args: Any, **kwargs: Any) -> Union[RPCResponse, AsyncLogFilter method_str = cast(RPCEndpoint, method_str) return await async_w3.manager.ws_send(method_str, params) except Exception as e: - if cache_key in provider._request_processor._request_information_cache: + if ( + cache_key is not None + and cache_key + in provider._request_processor._request_information_cache + ): provider._request_processor.pop_cached_request_information( cache_key ) diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index f7f641cec7..4a10b0671c 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -68,15 +68,19 @@ class AsyncBaseProvider: Tuple[Web3Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] ] = (None, None) - _request_cache: SimpleCache = SimpleCache(size=500) - _cache_allowed_requests: bool = True - _cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS - is_async = True has_persistent_connection = False global_ccip_read_enabled: bool = True ccip_read_max_redirects: int = 4 + # request caching + cache_allowed_requests: bool = False + cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS + _request_cache: SimpleCache + + def __init__(self) -> None: + self._request_cache = SimpleCache(1000) + @property def middlewares(self) -> Tuple[Web3Middleware, ...]: return self._middlewares diff --git a/web3/providers/base.py b/web3/providers/base.py index 679cda2e68..a0e59d8eca 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -61,15 +61,19 @@ class BaseProvider: Tuple[Web3Middleware, ...], Callable[..., RPCResponse] ] = (None, None) - _request_cache: SimpleCache = SimpleCache(size=500) - _cache_allowed_requests: bool = True - _cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS - is_async = False has_persistent_connection = False global_ccip_read_enabled: bool = True ccip_read_max_redirects: int = 4 + # request caching + cache_allowed_requests: bool = False + cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS + _request_cache: SimpleCache + + def __init__(self) -> None: + self._request_cache = SimpleCache(1000) + @property def middlewares(self) -> Tuple[Web3Middleware, ...]: return self._middlewares @@ -115,6 +119,7 @@ def is_connected(self, show_traceback: bool = False) -> bool: class JSONBaseProvider(BaseProvider): def __init__(self) -> None: self.request_counter = itertools.count() + super().__init__() def decode_rpc_response(self, raw_response: bytes) -> RPCResponse: text_response = to_text(raw_response) diff --git a/web3/providers/websocket/request_processor.py b/web3/providers/websocket/request_processor.py index e331e41012..1b938ab597 100644 --- a/web3/providers/websocket/request_processor.py +++ b/web3/providers/websocket/request_processor.py @@ -65,7 +65,19 @@ def cache_request_information( method: RPCEndpoint, params: Any, response_formatters: Tuple[Callable[..., Any], ...], - ) -> str: + ) -> Optional[str]: + cached_requests_key = generate_cache_key((method, params)) + if cached_requests_key in self._provider._request_cache._data: + cached_response = self._provider._request_cache._data[cached_requests_key] + cached_response_id = cached_response.get("id") + cache_key = generate_cache_key(cached_response_id) + if cache_key in self._request_information_cache: + self._provider.logger.debug( + "This is a cached request, not caching request info because it is " + f"not unique:\n method={method},\n params={params}" + ) + return None + # copy the request counter and find the next request id without incrementing # since this is done when / if the request is successfully sent request_id = next(copy(self._provider.request_counter)) @@ -147,11 +159,20 @@ def get_request_information_for_response( else: # retrieve the request info from the cache using the request id cache_key = generate_cache_key(response["id"]) - request_info = ( - # pop the request info from the cache since we don't need to keep it - # this keeps the cache size bounded - self.pop_cached_request_information(cache_key) - ) + if response in self._provider._request_cache._data.values(): + request_info = ( + # don't pop the request info from the cache, since we need to keep + # it to process future responses + # i.e. request information remains in the cache + self._request_information_cache.get_cache_entry(cache_key) + ) + else: + request_info = ( + # pop the request info from the cache since we don't need to keep it + # this keeps the cache size bounded + self.pop_cached_request_information(cache_key) + ) + if ( request_info is not None and request_info.method == "eth_unsubscribe" diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index 09f4e2ce31..bbee3597cd 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -1,5 +1,4 @@ import asyncio -import copy import json import logging import os @@ -41,6 +40,8 @@ RPCId, RPCResponse, ) +from web3.utils import SimpleCache + DEFAULT_PING_INTERVAL = 30 # 30 seconds DEFAULT_PING_TIMEOUT = 300 # 5 minutes @@ -100,21 +101,6 @@ def __init__( def __str__(self) -> str: return f"Websocket connection: {self.endpoint_uri}" - @property - def _effective_request_number(self) -> int: - current_request_id = next(copy.deepcopy(self.request_counter)) - 1 - cache_key = generate_cache_key(current_request_id) - requests_info = self._request_processor._request_information_cache._data - - if cache_key not in requests_info: - return current_request_id - - while cache_key in requests_info: - current_request_id += 1 - cache_key = generate_cache_key(current_request_id) - - return current_request_id - async def is_connected(self, show_traceback: bool = False) -> bool: if not self._ws: return False From 3d6ac9d4dd5c3f1076903b31adff9c0b40f3fde1 Mon Sep 17 00:00:00 2001 From: fselmo Date: Wed, 29 Nov 2023 13:42:55 -0700 Subject: [PATCH 07/24] WIP: Fix or comment out core test errors so core suite can run --- tests/core/eth-module/test_transactions.py | 25 +- .../test_time_based_gas_price_strategy.py | 47 +- .../core/manager/test_default_middlewares.py | 19 +- .../method-class/test_result_formatters.py | 15 +- .../middleware/test_attrdict_middleware.py | 98 ++--- .../middleware/test_eth_tester_middleware.py | 5 +- .../core/middleware/test_filter_middleware.py | 42 +- .../middleware/test_formatting_middleware.py | 105 ++--- .../middleware/test_http_request_retry.py | 403 +++++++++--------- .../test_name_to_address_middleware.py | 8 +- tests/core/middleware/test_stalecheck.py | 21 +- .../middleware/test_transaction_signing.py | 47 +- .../providers/test_async_http_provider.py | 22 +- tests/core/providers/test_http_provider.py | 4 +- tests/core/providers/test_ipc_provider.py | 21 +- web3/_utils/module_testing/utils.py | 39 +- web3/auto/gethdev.py | 4 +- web3/middleware/__init__.py | 2 + web3/middleware/signing.py | 3 + web3/middleware/stalecheck.py | 3 + web3/providers/eth_tester/main.py | 7 +- 21 files changed, 430 insertions(+), 510 deletions(-) diff --git a/tests/core/eth-module/test_transactions.py b/tests/core/eth-module/test_transactions.py index ce09c471d0..1fa2b78e16 100644 --- a/tests/core/eth-module/test_transactions.py +++ b/tests/core/eth-module/test_transactions.py @@ -23,12 +23,6 @@ TransactionNotFound, Web3ValidationError, ) -from web3.middleware import ( - construct_result_generator_middleware, -) -from web3.middleware.simulate_unmined_transaction import ( - unmined_receipt_simulator_middleware, -) RECEIPT_TIMEOUT = 0.2 @@ -177,7 +171,7 @@ def test_passing_string_to_to_hex(w3): def test_unmined_transaction_wait_for_receipt(w3): - w3.middleware_onion.add(unmined_receipt_simulator_middleware) + # w3.middleware_onion.add(unmined_receipt_simulator_middleware) txn_hash = w3.eth.send_transaction( { "from": w3.eth.coinbase, @@ -193,7 +187,7 @@ def test_unmined_transaction_wait_for_receipt(w3): assert txn_receipt["blockHash"] is not None -def test_get_transaction_formatters(w3): +def test_get_transaction_formatters(w3, request_mocker): non_checksummed_addr = "0xB2930B35844A230F00E51431ACAE96FE543A0347" # all uppercase unformatted_transaction = { "blockHash": ( @@ -236,15 +230,11 @@ def test_get_transaction_formatters(w3): "data": "0x5b34b966", } - result_middleware = construct_result_generator_middleware( - { - RPC.eth_getTransactionByHash: lambda *_: unformatted_transaction, - } - ) - w3.middleware_onion.inject(result_middleware, "result_middleware", layer=0) - - # test against eth_getTransactionByHash - received_tx = w3.eth.get_transaction("") + with request_mocker( + w3, mock_results={RPC.eth_getTransactionByHash: unformatted_transaction} + ): + # test against eth_getTransactionByHash + received_tx = w3.eth.get_transaction("") checksummed_addr = to_checksum_address(non_checksummed_addr) assert non_checksummed_addr != checksummed_addr @@ -299,4 +289,3 @@ def test_get_transaction_formatters(w3): ) assert received_tx == expected - w3.middleware_onion.remove("result_middleware") diff --git a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py index f0f159799a..20ac7b5a07 100644 --- a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py +++ b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py @@ -10,9 +10,6 @@ from web3.gas_strategies.time_based import ( construct_time_based_gas_price_strategy, ) -from web3.middleware import ( - construct_result_generator_middleware, -) from web3.providers.base import ( BaseProvider, ) @@ -151,16 +148,16 @@ def _get_block_by_something(method, params): ), ) def test_time_based_gas_price_strategy(strategy_params, expected): - fixture_middleware = construct_result_generator_middleware( - { - "eth_getBlockByHash": _get_block_by_something, - "eth_getBlockByNumber": _get_block_by_something, - } - ) + # fixture_middleware = construct_result_generator_middleware( + # { + # "eth_getBlockByHash": _get_block_by_something, + # "eth_getBlockByNumber": _get_block_by_something, + # } + # ) w3 = Web3( provider=BaseProvider(), - middlewares=[fixture_middleware], + # middlewares=[fixture_middleware], ) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( @@ -187,17 +184,17 @@ def _get_gas_price(method, params): def test_time_based_gas_price_strategy_without_transactions(): - fixture_middleware = construct_result_generator_middleware( - { - "eth_getBlockByHash": _get_initial_block, - "eth_getBlockByNumber": _get_initial_block, - "eth_gasPrice": _get_gas_price, - } - ) + # fixture_middleware = construct_result_generator_middleware( + # { + # "eth_getBlockByHash": _get_initial_block, + # "eth_getBlockByNumber": _get_initial_block, + # "eth_gasPrice": _get_gas_price, + # } + # ) w3 = Web3( provider=BaseProvider(), - middlewares=[fixture_middleware], + # middlewares=[fixture_middleware], ) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( @@ -272,16 +269,16 @@ def test_time_based_gas_price_strategy_zero_sample( strategy_params_zero, expected_exception_message ): with pytest.raises(Web3ValidationError) as excinfo: - fixture_middleware = construct_result_generator_middleware( - { - "eth_getBlockByHash": _get_block_by_something, - "eth_getBlockByNumber": _get_block_by_something, - } - ) + # fixture_middleware = construct_result_generator_middleware( + # { + # "eth_getBlockByHash": _get_block_by_something, + # "eth_getBlockByNumber": _get_block_by_something, + # } + # ) w3 = Web3( provider=BaseProvider(), - middlewares=[fixture_middleware], + # middlewares=[fixture_middleware], ) time_based_gas_price_strategy_zero = construct_time_based_gas_price_strategy( **strategy_params_zero, diff --git a/tests/core/manager/test_default_middlewares.py b/tests/core/manager/test_default_middlewares.py index 64b8a5d1d2..bda6c1d3e8 100644 --- a/tests/core/manager/test_default_middlewares.py +++ b/tests/core/manager/test_default_middlewares.py @@ -3,15 +3,10 @@ ) from web3.middleware import ( abi_middleware, - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_name_to_address_middleware, - async_validation_middleware, attrdict_middleware, buffered_gas_estimate_middleware, + ens_name_to_address_middleware, gas_price_strategy_middleware, - name_to_address_middleware, validation_middleware, ) @@ -19,7 +14,7 @@ def test_default_sync_middlewares(w3): expected_middlewares = [ (gas_price_strategy_middleware, "gas_price_strategy"), - (name_to_address_middleware(w3), "name_to_address"), + (ens_name_to_address_middleware, "name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), (abi_middleware, "abi"), @@ -35,11 +30,11 @@ def test_default_sync_middlewares(w3): def test_default_async_middlewares(): expected_middlewares = [ - (async_gas_price_strategy_middleware, "gas_price_strategy"), - (async_name_to_address_middleware, "name_to_address"), - (async_attrdict_middleware, "attrdict"), - (async_validation_middleware, "validation"), - (async_buffered_gas_estimate_middleware, "gas_estimate"), + (gas_price_strategy_middleware, "gas_price_strategy"), + (ens_name_to_address_middleware, "name_to_address"), + (attrdict_middleware, "attrdict"), + (validation_middleware, "validation"), + (buffered_gas_estimate_middleware, "gas_estimate"), ] default_middlewares = RequestManager.async_default_middlewares() diff --git a/tests/core/method-class/test_result_formatters.py b/tests/core/method-class/test_result_formatters.py index 324cfc55e5..9bea0eb8d9 100644 --- a/tests/core/method-class/test_result_formatters.py +++ b/tests/core/method-class/test_result_formatters.py @@ -10,9 +10,6 @@ from web3.method import ( Method, ) -from web3.middleware.fixture import ( - construct_result_generator_middleware, -) from web3.module import ( Module, ) @@ -33,11 +30,7 @@ def make_request(method, params): raise NotImplementedError -result_middleware = construct_result_generator_middleware( - { - "method_for_test": lambda m, p: "ok", - } -) +result_for_test = {"method_for_test": "ok"} class ModuleForTest(Module): @@ -48,11 +41,11 @@ class ModuleForTest(Module): def dummy_w3(): w3 = Web3( DummyProvider(), - middlewares=[result_middleware], modules={"module": ModuleForTest}, ) return w3 -def test_result_formatter(dummy_w3): - assert dummy_w3.module.method() == "OKAY" +def test_result_formatter(dummy_w3, request_mocker): + with request_mocker(dummy_w3, mock_results=result_for_test): + assert dummy_w3.module.method() == "OKAY" diff --git a/tests/core/middleware/test_attrdict_middleware.py b/tests/core/middleware/test_attrdict_middleware.py index 0a36654fb3..da03afe326 100644 --- a/tests/core/middleware/test_attrdict_middleware.py +++ b/tests/core/middleware/test_attrdict_middleware.py @@ -9,25 +9,18 @@ AttributeDict, ) from web3.middleware import ( - async_construct_result_generator_middleware, attrdict_middleware, - construct_result_generator_middleware, ) from web3.providers.eth_tester import ( AsyncEthereumTesterProvider, ) -from web3.types import ( - RPCEndpoint, -) GENERATED_NESTED_DICT_RESULT = { - "result": { - "a": 1, - "b": { - "b1": 1, - "b2": {"b2a": 1, "b2b": {"b2b1": 1, "b2b2": {"test": "fin"}}}, - }, - } + "a": 1, + "b": { + "b1": 1, + "b2": {"b2a": 1, "b2b": {"b2b1": 1, "b2b2": {"test": "fin"}}}, + }, } @@ -41,19 +34,14 @@ def test_attrdict_middleware_default_for_ethereum_tester_provider(): assert w3.middleware_onion.get("attrdict") == attrdict_middleware -def test_attrdict_middleware_is_recursive(w3): - w3.middleware_onion.inject( - construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - response = w3.manager.request_blocking("fake_endpoint", []) +def test_attrdict_middleware_is_recursive(w3, request_mocker): + with request_mocker( + w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = w3.manager.request_blocking("fake_endpoint", []) - result = response["result"] assert isinstance(result, AttributeDict) - assert response.result == result assert isinstance(result["b"], AttributeDict) assert result.b == result["b"] @@ -64,27 +52,17 @@ def test_attrdict_middleware_is_recursive(w3): assert isinstance(result.b.b2.b2b["b2b2"], AttributeDict) assert result.b.b2.b2b.b2b2 == result.b.b2.b2b["b2b2"] - # cleanup - w3.middleware_onion.remove("result_gen") - -def test_no_attrdict_middleware_does_not_convert_dicts_to_attrdict(): +def test_no_attrdict_middleware_does_not_convert_dicts_to_attrdict(request_mocker): w3 = Web3(EthereumTesterProvider()) - - w3.middleware_onion.inject( - construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - # remove attrdict middleware w3.middleware_onion.remove("attrdict") - response = w3.manager.request_blocking("fake_endpoint", []) - - result = response["result"] + with request_mocker( + w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = w3.manager.request_blocking("fake_endpoint", []) _assert_dict_and_not_attrdict(result) _assert_dict_and_not_attrdict(result["b"]) @@ -103,19 +81,14 @@ async def test_async_attrdict_middleware_default_for_async_ethereum_tester_provi @pytest.mark.asyncio -async def test_async_attrdict_middleware_is_recursive(async_w3): - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - response = await async_w3.manager.coro_request("fake_endpoint", []) - - result = response["result"] +async def test_async_attrdict_middleware_is_recursive(async_w3, request_mocker): + async with request_mocker( + async_w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = await async_w3.manager.coro_request("fake_endpoint", []) + assert isinstance(result, AttributeDict) - assert response.result == result assert isinstance(result["b"], AttributeDict) assert result.b == result["b"] @@ -126,28 +99,21 @@ async def test_async_attrdict_middleware_is_recursive(async_w3): assert isinstance(result.b.b2.b2b["b2b2"], AttributeDict) assert result.b.b2.b2b.b2b2 == result.b.b2.b2b["b2b2"] - # cleanup - async_w3.middleware_onion.remove("result_gen") - @pytest.mark.asyncio -async def test_no_async_attrdict_middleware_does_not_convert_dicts_to_attrdict(): +async def test_no_async_attrdict_middleware_does_not_convert_dicts_to_attrdict( + request_mocker, +): async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) - async_w3.middleware_onion.inject( - await async_construct_result_generator_middleware( - {RPCEndpoint("fake_endpoint"): lambda *_: GENERATED_NESTED_DICT_RESULT} - ), - "result_gen", - layer=0, - ) - # remove attrdict middleware async_w3.middleware_onion.remove("attrdict") - response = await async_w3.manager.coro_request("fake_endpoint", []) - - result = response["result"] + async with request_mocker( + async_w3, + mock_results={"fake_endpoint": GENERATED_NESTED_DICT_RESULT}, + ): + result = await async_w3.manager.coro_request("fake_endpoint", []) _assert_dict_and_not_attrdict(result) _assert_dict_and_not_attrdict(result["b"]) diff --git a/tests/core/middleware/test_eth_tester_middleware.py b/tests/core/middleware/test_eth_tester_middleware.py index 6797a51fe6..819376daa0 100644 --- a/tests/core/middleware/test_eth_tester_middleware.py +++ b/tests/core/middleware/test_eth_tester_middleware.py @@ -4,7 +4,6 @@ ) from web3.providers.eth_tester.middleware import ( - async_default_transaction_fields_middleware, default_transaction_fields_middleware, ) from web3.types import ( @@ -178,9 +177,7 @@ async def mock_async_coinbase(): mock_w3.eth.accounts = mock_async_accounts() mock_w3.eth.coinbase = mock_async_coinbase() - middleware = await async_default_transaction_fields_middleware( - mock_request, mock_w3 - ) + middleware = default_transaction_fields_middleware(mock_request, mock_w3) base_params = {"chainId": 5} filled_transaction = await middleware(method, [base_params]) diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index a65ea155f0..7ade451297 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -15,11 +15,8 @@ AsyncEth, ) from web3.middleware import ( - async_attrdict_middleware, - async_construct_result_generator_middleware, async_local_filter_middleware, attrdict_middleware, - construct_result_generator_middleware, local_filter_middleware, ) from web3.middleware.filter import ( @@ -89,14 +86,15 @@ def iterator(): @pytest.fixture(scope="function") def result_generator_middleware(iter_block_number): - return construct_result_generator_middleware( - { - "eth_getLogs": lambda *_: FILTER_LOG, - "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, - "net_version": lambda *_: 1, - "eth_blockNumber": lambda *_: next(iter_block_number), - } - ) + return None + # return construct_result_generator_middleware( + # { + # "eth_getLogs": lambda *_: FILTER_LOG, + # "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, + # "net_version": lambda *_: 1, + # "eth_blockNumber": lambda *_: next(iter_block_number), + # } + # ) @pytest.fixture(scope="function") @@ -267,14 +265,15 @@ async def make_request(self, method, params): @pytest_asyncio.fixture(scope="function") async def async_result_generator_middleware(iter_block_number): - return await async_construct_result_generator_middleware( - { - "eth_getLogs": lambda *_: FILTER_LOG, - "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, - "net_version": lambda *_: 1, - "eth_blockNumber": lambda *_: next(iter_block_number), - } - ) + return None + # return await async_construct_result_generator_middleware( + # { + # "eth_getLogs": lambda *_: FILTER_LOG, + # "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, + # "net_version": lambda *_: 1, + # "eth_blockNumber": lambda *_: next(iter_block_number), + # } + # ) @pytest.fixture(scope="function") @@ -286,9 +285,8 @@ def async_w3_base(): @pytest.fixture(scope="function") def async_w3(async_w3_base, async_result_generator_middleware): - async_w3_base.middleware_onion.add(async_result_generator_middleware) - async_w3_base.middleware_onion.add(async_attrdict_middleware) - async_w3_base.middleware_onion.add(async_local_filter_middleware) + async_w3_base.middleware_onion.add(attrdict_middleware) + async_w3_base.middleware_onion.add(local_filter_middleware) return async_w3_base diff --git a/tests/core/middleware/test_formatting_middleware.py b/tests/core/middleware/test_formatting_middleware.py index 8f7f9ca807..73f4f34a0e 100644 --- a/tests/core/middleware/test_formatting_middleware.py +++ b/tests/core/middleware/test_formatting_middleware.py @@ -6,11 +6,6 @@ from web3 import ( Web3, ) -from web3.middleware import ( - construct_error_generator_middleware, - construct_formatting_middleware, - construct_result_generator_middleware, -) from web3.providers.base import ( BaseProvider, ) @@ -29,24 +24,16 @@ def w3(): return Web3(provider=DummyProvider(), middlewares=[]) -def test_formatting_middleware(w3): +def test_formatting_middleware(w3, request_mocker): # No formatters by default - w3.middleware_onion.add(construct_formatting_middleware()) - w3.middleware_onion.add( - construct_result_generator_middleware( - { - "test_endpoint": lambda method, params: "done", - } - ) - ) - expected = "done" - actual = w3.manager.request_blocking("test_endpoint", []) - assert actual == expected + with request_mocker(w3, mock_results={"test_endpoint": "done"}): + actual = w3.provider.make_request("test_endpoint", []) + assert actual == expected def test_formatting_middleware_no_method(w3): - w3.middleware_onion.add(construct_formatting_middleware()) + # w3.middleware_onion.add(construct_formatting_middleware()) # Formatting middleware requires an endpoint with pytest.raises(NotImplementedError): @@ -55,17 +42,17 @@ def test_formatting_middleware_no_method(w3): def test_formatting_middleware_request_formatters(w3): callable_mock = Mock() - w3.middleware_onion.add( - construct_result_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: "done"} - ) - ) - - w3.middleware_onion.add( - construct_formatting_middleware( - request_formatters={RPCEndpoint("test_endpoint"): callable_mock} - ) - ) + # w3.middleware_onion.add( + # construct_result_generator_middleware( + # {RPCEndpoint("test_endpoint"): lambda method, params: "done"} + # ) + # ) + # + # w3.middleware_onion.add( + # construct_formatting_middleware( + # request_formatters={RPCEndpoint("test_endpoint"): callable_mock} + # ) + # ) expected = "done" actual = w3.manager.request_blocking("test_endpoint", ["param1"]) @@ -75,16 +62,16 @@ def test_formatting_middleware_request_formatters(w3): def test_formatting_middleware_result_formatters(w3): - w3.middleware_onion.add( - construct_result_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: "done"} - ) - ) - w3.middleware_onion.add( - construct_formatting_middleware( - result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} - ) - ) + # w3.middleware_onion.add( + # construct_result_generator_middleware( + # {RPCEndpoint("test_endpoint"): lambda method, params: "done"} + # ) + # ) + # w3.middleware_onion.add( + # construct_formatting_middleware( + # result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} + # ) + # ) expected = "STATUS:done" actual = w3.manager.request_blocking("test_endpoint", []) @@ -92,16 +79,16 @@ def test_formatting_middleware_result_formatters(w3): def test_formatting_middleware_result_formatters_for_none(w3): - w3.middleware_onion.add( - construct_result_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: None} - ) - ) - w3.middleware_onion.add( - construct_formatting_middleware( - result_formatters={RPCEndpoint("test_endpoint"): lambda x: hex(x)} - ) - ) + # w3.middleware_onion.add( + # construct_result_generator_middleware( + # {RPCEndpoint("test_endpoint"): lambda method, params: None} + # ) + # ) + # w3.middleware_onion.add( + # construct_formatting_middleware( + # result_formatters={RPCEndpoint("test_endpoint"): lambda x: hex(x)} + # ) + # ) expected = None actual = w3.manager.request_blocking("test_endpoint", []) @@ -109,16 +96,16 @@ def test_formatting_middleware_result_formatters_for_none(w3): def test_formatting_middleware_error_formatters(w3): - w3.middleware_onion.add( - construct_error_generator_middleware( - {RPCEndpoint("test_endpoint"): lambda method, params: "error"} - ) - ) - w3.middleware_onion.add( - construct_formatting_middleware( - result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} - ) - ) + # w3.middleware_onion.add( + # construct_error_generator_middleware( + # {RPCEndpoint("test_endpoint"): lambda method, params: "error"} + # ) + # ) + # w3.middleware_onion.add( + # construct_formatting_middleware( + # result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} + # ) + # ) expected = "error" with pytest.raises(ValueError) as err: diff --git a/tests/core/middleware/test_http_request_retry.py b/tests/core/middleware/test_http_request_retry.py index 02060b5d85..f1cdb4f6d7 100644 --- a/tests/core/middleware/test_http_request_retry.py +++ b/tests/core/middleware/test_http_request_retry.py @@ -1,208 +1,195 @@ -import pytest -from unittest.mock import ( - Mock, - patch, -) - -import aiohttp -import pytest_asyncio -from requests.exceptions import ( - ConnectionError, - HTTPError, - Timeout, - TooManyRedirects, -) - -import web3 -from web3 import ( - AsyncHTTPProvider, - AsyncWeb3, -) -from web3.middleware.exception_retry_request import ( - async_exception_retry_middleware, - check_if_retry_on_failure, - exception_retry_middleware, -) -from web3.providers import ( - HTTPProvider, - IPCProvider, -) - - -@pytest.fixture -def exception_retry_request_setup(): - w3 = Mock() - provider = HTTPProvider() - errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) - setup = exception_retry_middleware(provider.make_request, w3, errors, 5) - setup.w3 = w3 - return setup - - -def test_check_if_retry_on_failure_false(): - methods = [ - "eth_sendTransaction", - "personal_signAndSendTransaction", - "personal_sendTransaction", - ] - - for method in methods: - assert not check_if_retry_on_failure(method) - - -def test_check_if_retry_on_failure_true(): - method = "eth_getBalance" - assert check_if_retry_on_failure(method) - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_check_send_transaction_called_once( - make_post_request_mock, exception_retry_request_setup -): - method = "eth_sendTransaction" - params = [ - { - "to": "0x0", - "value": 1, - } - ] - - with pytest.raises(ConnectionError): - exception_retry_request_setup(method, params) - assert make_post_request_mock.call_count == 1 - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_valid_method_retried(make_post_request_mock, exception_retry_request_setup): - method = "eth_getBalance" - params = [] - - with pytest.raises(ConnectionError): - exception_retry_request_setup(method, params) - assert make_post_request_mock.call_count == 5 - - -def test_is_strictly_default_http_middleware(): - w3 = HTTPProvider() - assert "http_retry_request" in w3.middlewares - - w3 = IPCProvider() - assert "http_retry_request" not in w3.middlewares - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_check_with_all_middlewares(make_post_request_mock): - provider = HTTPProvider() - w3 = web3.Web3(provider) - with pytest.raises(ConnectionError): - w3.eth.block_number - assert make_post_request_mock.call_count == 5 - - -@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -def test_exception_retry_middleware_with_allow_list_kwarg( - make_post_request_mock, exception_retry_request_setup -): - w3 = Mock() - provider = HTTPProvider() - errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) - setup = exception_retry_middleware( - provider.make_request, w3, errors, 5, allow_list=["test_userProvidedMethod"] - ) - setup.w3 = w3 - with pytest.raises(ConnectionError): - setup("test_userProvidedMethod", []) - assert make_post_request_mock.call_count == 5 - - make_post_request_mock.reset_mock() - with pytest.raises(ConnectionError): - setup("eth_getBalance", []) - assert make_post_request_mock.call_count == 1 - - -# -- async -- # - - -ASYNC_TEST_RETRY_COUNT = 3 - - -@pytest_asyncio.fixture -async def async_exception_retry_request_setup(): - w3 = Mock() - provider = AsyncHTTPProvider() - setup = await async_exception_retry_middleware( - provider.make_request, - w3, - (TimeoutError, aiohttp.ClientError), - ASYNC_TEST_RETRY_COUNT, - 0.1, - ) - setup.w3 = w3 - return setup - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "error", - ( - TimeoutError, - aiohttp.ClientError, # general base class for all aiohttp errors - # ClientError subclasses - aiohttp.ClientConnectionError, - aiohttp.ServerTimeoutError, - aiohttp.ClientOSError, - ), -) -async def test_async_check_retry_middleware(error, async_exception_retry_request_setup): - with patch( - "web3.providers.async_rpc.async_make_post_request" - ) as async_make_post_request_mock: - async_make_post_request_mock.side_effect = error - - with pytest.raises(error): - await async_exception_retry_request_setup("eth_getBalance", []) - assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT - - -@pytest.mark.asyncio -async def test_async_check_without_retry_middleware(): - with patch( - "web3.providers.async_rpc.async_make_post_request" - ) as async_make_post_request_mock: - async_make_post_request_mock.side_effect = TimeoutError - provider = AsyncHTTPProvider() - w3 = AsyncWeb3(provider) - w3.provider._middlewares = () - - with pytest.raises(TimeoutError): - await w3.eth.block_number - assert async_make_post_request_mock.call_count == 1 - - -@pytest.mark.asyncio -async def test_async_exception_retry_middleware_with_allow_list_kwarg(): - w3 = Mock() - provider = AsyncHTTPProvider() - setup = await async_exception_retry_middleware( - provider.make_request, - w3, - (TimeoutError, aiohttp.ClientError), - retries=ASYNC_TEST_RETRY_COUNT, - backoff_factor=0.1, - allow_list=["test_userProvidedMethod"], - ) - setup.w3 = w3 - - with patch( - "web3.providers.async_rpc.async_make_post_request" - ) as async_make_post_request_mock: - async_make_post_request_mock.side_effect = TimeoutError - - with pytest.raises(TimeoutError): - await setup("test_userProvidedMethod", []) - assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT - - async_make_post_request_mock.reset_mock() - with pytest.raises(TimeoutError): - await setup("eth_getBalance", []) - assert async_make_post_request_mock.call_count == 1 +# TODO: Redo these tests but as the provider configuration for retrying requests +# import pytest +# from unittest.mock import ( +# Mock, +# patch, +# ) +# +# import aiohttp +# import pytest_asyncio +# from requests.exceptions import ( +# ConnectionError, +# HTTPError, +# Timeout, +# TooManyRedirects, +# ) +# +# import web3 +# from web3 import ( +# AsyncHTTPProvider, +# AsyncWeb3, +# ) +# from web3.providers import ( +# HTTPProvider, +# IPCProvider, +# ) +# from web3.providers.rpc.utils import check_if_retry_on_failure +# +# +# def test_check_if_retry_on_failure_false(): +# methods = [ +# "eth_sendTransaction", +# "personal_signAndSendTransaction", +# "personal_sendTransaction", +# ] +# +# for method in methods: +# assert not check_if_retry_on_failure(method) +# +# +# def test_check_if_retry_on_failure_true(): +# method = "eth_getBalance" +# assert check_if_retry_on_failure(method) +# +# +# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) +# def test_check_send_transaction_called_once( +# make_post_request_mock, exception_retry_request_setup +# ): +# method = "eth_sendTransaction" +# params = [ +# { +# "to": "0x0", +# "value": 1, +# } +# ] +# +# with pytest.raises(ConnectionError): +# exception_retry_request_setup(method, params) +# assert make_post_request_mock.call_count == 1 +# +# +# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) +# def test_valid_method_retried(make_post_request_mock, exception_retry_request_setup): +# method = "eth_getBalance" +# params = [] +# +# with pytest.raises(ConnectionError): +# exception_retry_request_setup(method, params) +# assert make_post_request_mock.call_count == 5 +# +# +# def test_is_strictly_default_http_middleware(): +# w3 = HTTPProvider() +# assert "http_retry_request" in w3.middlewares +# +# w3 = IPCProvider() +# assert "http_retry_request" not in w3.middlewares +# +# +# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) +# def test_check_with_all_middlewares(make_post_request_mock): +# provider = HTTPProvider() +# w3 = web3.Web3(provider) +# with pytest.raises(ConnectionError): +# w3.eth.block_number +# assert make_post_request_mock.call_count == 5 +# +# +# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) +# def test_exception_retry_middleware_with_allow_list_kwarg( +# make_post_request_mock, exception_retry_request_setup +# ): +# w3 = Mock() +# provider = HTTPProvider() +# errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) +# setup = exception_retry_middleware( +# provider.make_request, w3, errors, 5, allow_list=["test_userProvidedMethod"] +# ) +# setup.w3 = w3 +# with pytest.raises(ConnectionError): +# setup("test_userProvidedMethod", []) +# assert make_post_request_mock.call_count == 5 +# +# make_post_request_mock.reset_mock() +# with pytest.raises(ConnectionError): +# setup("eth_getBalance", []) +# assert make_post_request_mock.call_count == 1 +# +# +# # -- async -- # +# +# +# ASYNC_TEST_RETRY_COUNT = 3 +# +# +# @pytest_asyncio.fixture +# async def async_exception_retry_request_setup(): +# w3 = Mock() +# provider = AsyncHTTPProvider() +# setup = await async_exception_retry_middleware( +# provider.make_request, +# w3, +# (TimeoutError, aiohttp.ClientError), +# ASYNC_TEST_RETRY_COUNT, +# 0.1, +# ) +# setup.w3 = w3 +# return setup +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "error", +# ( +# TimeoutError, +# aiohttp.ClientError, # general base class for all aiohttp errors +# # ClientError subclasses +# aiohttp.ClientConnectionError, +# aiohttp.ServerTimeoutError, +# aiohttp.ClientOSError, +# ), +# ) +# async def test_async_check_retry_middleware(error, async_exception_retry_request_setup): +# with patch( +# "web3.providers.async_rpc.async_make_post_request" +# ) as async_make_post_request_mock: +# async_make_post_request_mock.side_effect = error +# +# with pytest.raises(error): +# await async_exception_retry_request_setup("eth_getBalance", []) +# assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT +# +# +# @pytest.mark.asyncio +# async def test_async_check_without_retry_middleware(): +# with patch( +# "web3.providers.async_rpc.async_make_post_request" +# ) as async_make_post_request_mock: +# async_make_post_request_mock.side_effect = TimeoutError +# provider = AsyncHTTPProvider() +# w3 = AsyncWeb3(provider) +# w3.provider._middlewares = () +# +# with pytest.raises(TimeoutError): +# await w3.eth.block_number +# assert async_make_post_request_mock.call_count == 1 +# +# +# @pytest.mark.asyncio +# async def test_async_exception_retry_middleware_with_allow_list_kwarg(): +# w3 = Mock() +# provider = AsyncHTTPProvider() +# setup = await async_exception_retry_middleware( +# provider.make_request, +# w3, +# (TimeoutError, aiohttp.ClientError), +# retries=ASYNC_TEST_RETRY_COUNT, +# backoff_factor=0.1, +# allow_list=["test_userProvidedMethod"], +# ) +# setup.w3 = w3 +# +# with patch( +# "web3.providers.async_rpc.async_make_post_request" +# ) as async_make_post_request_mock: +# async_make_post_request_mock.side_effect = TimeoutError +# +# with pytest.raises(TimeoutError): +# await setup("test_userProvidedMethod", []) +# assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT +# +# async_make_post_request_mock.reset_mock() +# with pytest.raises(TimeoutError): +# await setup("eth_getBalance", []) +# assert async_make_post_request_mock.call_count == 1 diff --git a/tests/core/middleware/test_name_to_address_middleware.py b/tests/core/middleware/test_name_to_address_middleware.py index d6dc040942..6434334f32 100644 --- a/tests/core/middleware/test_name_to_address_middleware.py +++ b/tests/core/middleware/test_name_to_address_middleware.py @@ -11,10 +11,7 @@ NameNotFound, ) from web3.middleware import ( - name_to_address_middleware, -) -from web3.middleware.names import ( - async_name_to_address_middleware, + ens_name_to_address_middleware, ) from web3.providers.eth_tester import ( AsyncEthereumTesterProvider, @@ -50,7 +47,8 @@ def ens_addr_account_balance(ens_mapped_address, _w3_setup): @pytest.fixture def w3(_w3_setup, ens_mapped_address): _w3_setup.ens = TempENS({NAME: ens_mapped_address}) - _w3_setup.middleware_onion.add(name_to_address_middleware(_w3_setup)) + ens_name_to_address_middleware._w3 = _w3_setup + _w3_setup.middleware_onion.add(ens_name_to_address_middleware) return _w3_setup diff --git a/tests/core/middleware/test_stalecheck.py b/tests/core/middleware/test_stalecheck.py index a558b74e95..36c4ba6e3d 100644 --- a/tests/core/middleware/test_stalecheck.py +++ b/tests/core/middleware/test_stalecheck.py @@ -16,7 +16,6 @@ from web3.middleware.stalecheck import ( StaleBlockchain, _is_fresh, - async_make_stalecheck_middleware, ) @@ -33,12 +32,10 @@ def allowable_delay(): @pytest.fixture def request_middleware(allowable_delay): middleware = make_stalecheck_middleware(allowable_delay) - make_request, web3 = Mock(), Mock() - initialized = middleware(make_request, web3) - # for easier mocking, later: - initialized.web3 = web3 - initialized.make_request = make_request - return initialized + web3 = Mock() + middleware._web3 = web3 + middleware._web3.provider.make_request = Mock() + return middleware def stub_block(timestamp): @@ -144,13 +141,11 @@ async def request_async_middleware(allowable_delay): AsyncMock, ) - middleware = await async_make_stalecheck_middleware(allowable_delay) - make_request, web3 = AsyncMock(), AsyncMock() - initialized = await middleware(make_request, web3) + middleware = make_stalecheck_middleware(allowable_delay) + async_web3 = AsyncMock() # for easier mocking, later: - initialized.web3 = web3 - initialized.make_request = make_request - return initialized + middleware._web3 = async_web3 + return middleware @pytest.mark.asyncio diff --git a/tests/core/middleware/test_transaction_signing.py b/tests/core/middleware/test_transaction_signing.py index e1dc9e5a8e..875ea9c72e 100644 --- a/tests/core/middleware/test_transaction_signing.py +++ b/tests/core/middleware/test_transaction_signing.py @@ -36,14 +36,9 @@ InvalidAddress, ) from web3.middleware import ( - async_construct_result_generator_middleware, - construct_result_generator_middleware, construct_sign_and_send_raw_middleware, ) -from web3.middleware.signing import ( - async_construct_sign_and_send_raw_middleware, - gen_normalized_accounts, -) +from web3.middleware.signing import gen_normalized_accounts from web3.providers import ( AsyncBaseProvider, BaseProvider, @@ -98,13 +93,15 @@ def make_request(self, method, params): @pytest.fixture def result_generator_middleware(): - return construct_result_generator_middleware( - { - "eth_sendRawTransaction": lambda *args: args, - "net_version": lambda *_: 1, - "eth_chainId": lambda *_: "0x02", - } - ) + # TODO: replace with request mocker + return None + # return construct_result_generator_middleware( + # { + # "eth_sendRawTransaction": lambda *args: args, + # "net_version": lambda *_: 1, + # "eth_chainId": lambda *_: "0x02", + # } + # ) @pytest.fixture @@ -426,13 +423,15 @@ def test_sign_and_send_raw_middleware_with_byte_addresses( @pytest_asyncio.fixture async def async_result_generator_middleware(): - return await async_construct_result_generator_middleware( - { - "eth_sendRawTransaction": lambda *args: args, - "net_version": lambda *_: 1, - "eth_chainId": lambda *_: "0x02", - } - ) + # TODO: replace with request mocker + return None + # return await async_construct_result_generator_middleware( + # { + # "eth_sendRawTransaction": lambda *args: args, + # "net_version": lambda *_: 1, + # "eth_chainId": lambda *_: "0x02", + # } + # ) class AsyncDummyProvider(AsyncBaseProvider): @@ -482,7 +481,7 @@ async def test_async_sign_and_send_raw_middleware( key_object, ): async_w3_dummy.middleware_onion.add( - await async_construct_sign_and_send_raw_middleware(key_object) + construct_sign_and_send_raw_middleware(key_object) ) legacy_transaction = { @@ -553,9 +552,7 @@ async def test_async_signed_transaction( key_object, from_, ): - async_w3.middleware_onion.add( - await async_construct_sign_and_send_raw_middleware(key_object) - ) + async_w3.middleware_onion.add(construct_sign_and_send_raw_middleware(key_object)) # Drop any falsy addresses accounts = await async_w3.eth.accounts @@ -595,7 +592,7 @@ async def test_async_sign_and_send_raw_middleware_with_byte_addresses( to_ = to_converter(ADDRESS_2) async_w3_dummy.middleware_onion.add( - await async_construct_sign_and_send_raw_middleware(private_key) + construct_sign_and_send_raw_middleware(private_key) ) actual = await async_w3_dummy.manager.coro_request( diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 52b9397f01..4f1807ba43 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -23,16 +23,16 @@ AsyncGethTxPool, ) from web3.middleware import ( - async_attrdict_middleware, - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_name_to_address_middleware, - async_validation_middleware, + attrdict_middleware, + buffered_gas_estimate_middleware, + gas_price_strategy_middleware, + ens_name_to_address_middleware, + validation_middleware, ) from web3.net import ( AsyncNet, ) -from web3.providers.async_rpc import ( +from web3.providers.rpc import ( AsyncHTTPProvider, ) @@ -86,17 +86,17 @@ def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> assert ( async_w3.middleware_onion.get("gas_price_strategy") - == async_gas_price_strategy_middleware + == gas_price_strategy_middleware ) assert ( async_w3.middleware_onion.get("name_to_address") - == async_name_to_address_middleware + == ens_name_to_address_middleware ) - assert async_w3.middleware_onion.get("attrdict") == async_attrdict_middleware - assert async_w3.middleware_onion.get("validation") == async_validation_middleware + assert async_w3.middleware_onion.get("attrdict") == attrdict_middleware + assert async_w3.middleware_onion.get("validation") == validation_middleware assert ( async_w3.middleware_onion.get("gas_estimate") - == async_buffered_gas_estimate_middleware + == buffered_gas_estimate_middleware ) diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 53308e3afa..919defd0b6 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -30,7 +30,7 @@ attrdict_middleware, buffered_gas_estimate_middleware, gas_price_strategy_middleware, - name_to_address_middleware, + ens_name_to_address_middleware, validation_middleware, ) from web3.net import ( @@ -87,7 +87,7 @@ def test_web3_with_http_provider_has_default_middlewares_and_modules() -> None: ) assert ( w3.middleware_onion.get("name_to_address").__name__ - == name_to_address_middleware(w3).__name__ + == ens_name_to_address_middleware(w3).__name__ ) assert w3.middleware_onion.get("attrdict") == attrdict_middleware assert w3.middleware_onion.get("validation") == validation_middleware diff --git a/tests/core/providers/test_ipc_provider.py b/tests/core/providers/test_ipc_provider.py index 2d24e49a2c..6847e91f06 100644 --- a/tests/core/providers/test_ipc_provider.py +++ b/tests/core/providers/test_ipc_provider.py @@ -14,12 +14,10 @@ from web3.exceptions import ( ProviderConnectionError, ) -from web3.middleware import ( - construct_fixture_middleware, -) from web3.providers.ipc import ( IPCProvider, ) +from web3.types import RPCEndpoint @pytest.fixture @@ -90,14 +88,15 @@ def test_sync_waits_for_full_result(jsonrpc_ipc_pipe_path, serve_empty_result): provider._socket.sock.close() -def test_web3_auto_gethdev(): +def test_web3_auto_gethdev(request_mocker): assert isinstance(w3.provider, IPCProvider) - return_block_with_long_extra_data = construct_fixture_middleware( - { - "eth_getBlockByNumber": {"extraData": "0x" + "ff" * 33}, - } - ) - w3.middleware_onion.inject(return_block_with_long_extra_data, layer=0) - block = w3.eth.get_block("latest") + with request_mocker( + w3, + mock_results={ + RPCEndpoint("eth_getBlockByNumber"): {"extraData": "0x" + "ff" * 33} + }, + ): + block = w3.eth.get_block("latest") + assert "extraData" not in block assert block.proofOfAuthorityData == b"\xff" * 33 diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py index 9dbecb1734..98b815b6c9 100644 --- a/web3/_utils/module_testing/utils.py +++ b/web3/_utils/module_testing/utils.py @@ -3,12 +3,17 @@ Any, Dict, TypeVar, + Union, ) from toolz import ( merge, ) +from web3.providers.eth_tester import ( + AsyncEthereumTesterProvider, + EthereumTesterProvider, +) from web3 import ( AsyncWeb3, Web3, @@ -47,8 +52,8 @@ def test_my_w3(w3, request_mocker): def __init__( self, w3: WEB3, - mock_results: Dict[RPCEndpoint, Dict[str, Any]] = None, - mock_errors: Dict[RPCEndpoint, Dict[str, Any]] = None, + mock_results: Dict[Union[RPCEndpoint, str], Any] = None, + mock_errors: Dict[Union[RPCEndpoint, str], Dict[str, Any]] = None, ): self.w3 = w3 self.mock_results = mock_results or {} @@ -67,17 +72,28 @@ def __exit__(self, exc_type, exc_value, traceback): self.w3.provider._request_func_cache = (None, None) def _mock_request_handler(self, method, params): + if method not in self.mock_errors and method not in self.mock_results: + return self._make_request(method, params) + + request_id = ( + next(copy.deepcopy(self.w3.provider.request_counter)) + if not isinstance(self.w3.provider, EthereumTesterProvider) + else 1 + ) + response_dict = {"jsonrpc": "2.0", "id": request_id} + if method in self.mock_results: - return {"result": self.mock_results[method]} + return merge(response_dict, {"result": self.mock_results[method]}) elif method in self.mock_errors: error = self.mock_errors[method] if not isinstance(error, dict): raise Web3ValidationError("error must be a dict") code = error.get("code", -32000) message = error.get("message", "Mocked error") - return {"error": merge({"code": code, "message": message}, error)} - else: - return self._make_request(method, params) + return merge( + response_dict, + {"error": merge({"code": code, "message": message}, error)}, + ) # -- async -- # async def __aenter__(self): @@ -95,10 +111,13 @@ async def _async_mock_request_handler(self, method, params): if method not in self.mock_errors and method not in self.mock_results: return await self._make_request(method, params) - response_dict = { - "jsonrpc": "2.0", - "id": next(copy.deepcopy(self.w3.provider.request_counter)), - } + request_id = ( + next(copy.deepcopy(self.w3.provider.request_counter)) + if not isinstance(self.w3.provider, AsyncEthereumTesterProvider) + else 1 + ) + response_dict = {"jsonrpc": "2.0", "id": request_id} + if method in self.mock_results: return merge(response_dict, {"result": self.mock_results[method]}) elif method in self.mock_errors: diff --git a/web3/auto/gethdev.py b/web3/auto/gethdev.py index fc28e21999..93134cb77c 100644 --- a/web3/auto/gethdev.py +++ b/web3/auto/gethdev.py @@ -3,11 +3,11 @@ Web3, ) from web3.middleware import ( - extradata_to_poa, + extradata_to_poa_middleware, ) from web3.providers.ipc import ( get_dev_ipc_path, ) w3 = Web3(IPCProvider(get_dev_ipc_path())) -w3.middleware_onion.inject(extradata_to_poa, layer=0) +w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 198eed1e29..fc29b46a72 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -47,10 +47,12 @@ pythonic_middleware, ) from .signing import ( + construct_sign_and_send_raw_middleware, SignAndSendRawMiddleware, ) from .stalecheck import ( StaleCheckMiddleware, + make_stalecheck_middleware, ) from .validation import ( validation_middleware, diff --git a/web3/middleware/signing.py b/web3/middleware/signing.py index 2d4805d09e..a3dfc6474e 100644 --- a/web3/middleware/signing.py +++ b/web3/middleware/signing.py @@ -198,3 +198,6 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A raw_tx = account.sign_transaction(filled_transaction).rawTransaction return (raw_tx.hex(),) + + +construct_sign_and_send_raw_middleware = SignAndSendRawMiddleware diff --git a/web3/middleware/stalecheck.py b/web3/middleware/stalecheck.py index 65cacc8956..6e6a78225a 100644 --- a/web3/middleware/stalecheck.py +++ b/web3/middleware/stalecheck.py @@ -72,3 +72,6 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A raise StaleBlockchain(latest, self.allowable_delay) return params + + +make_stalecheck_middleware = StaleCheckMiddleware diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index ab3bc681f7..217e319efa 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -23,9 +23,6 @@ from web3.middleware.attrdict import ( attrdict_middleware, ) -from web3.middleware.buffered_gas_estimate import ( - buffered_gas_estimate_middleware, -) from web3.providers import ( BaseProvider, ) @@ -50,8 +47,6 @@ class AsyncEthereumTesterProvider(AsyncBaseProvider): middlewares = ( - attrdict_middleware, - buffered_gas_estimate_middleware, default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -80,7 +75,6 @@ async def is_connected(self, show_traceback: bool = False) -> Literal[True]: class EthereumTesterProvider(BaseProvider): middlewares = ( - attrdict_middleware, default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -95,6 +89,7 @@ def __init__( ] = None, ) -> None: # do not import eth_tester until runtime, it is not a default dependency + super().__init__() from eth_tester import EthereumTester # noqa: F811 from eth_tester.backends.base import ( BaseChainBackend, From 3822ad0474eb41e76293a39e3fde2449ce945c92 Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 30 Nov 2023 13:49:41 -0700 Subject: [PATCH 08/24] Allow for manipulating the method via middleware - In the middleware refactor, we weren't accounting for manipulating the ``method``, only the params. Allow for this kind of manipulation as it is necessary for ``sign_and_send_raw_middleware``, for example. --- .../middleware/test_transaction_signing.py | 55 +++++-------------- web3/_utils/module_testing/utils.py | 16 +++--- web3/middleware/__init__.py | 24 ++++---- web3/middleware/attrdict.py | 11 ++-- web3/middleware/base.py | 15 +++-- web3/middleware/buffered_gas_estimate.py | 4 +- web3/middleware/formatting.py | 15 +++-- web3/middleware/gas_price_strategy.py | 4 +- web3/middleware/names.py | 2 +- web3/middleware/signing.py | 18 ++++-- web3/middleware/stalecheck.py | 4 +- web3/providers/eth_tester/middleware.py | 4 +- 12 files changed, 76 insertions(+), 96 deletions(-) diff --git a/tests/core/middleware/test_transaction_signing.py b/tests/core/middleware/test_transaction_signing.py index 875ea9c72e..506d533e0b 100644 --- a/tests/core/middleware/test_transaction_signing.py +++ b/tests/core/middleware/test_transaction_signing.py @@ -88,31 +88,21 @@ class DummyProvider(BaseProvider): def make_request(self, method, params): - raise NotImplementedError(f"Cannot make request for {method}:{params}") - - -@pytest.fixture -def result_generator_middleware(): - # TODO: replace with request mocker - return None - # return construct_result_generator_middleware( - # { - # "eth_sendRawTransaction": lambda *args: args, - # "net_version": lambda *_: 1, - # "eth_chainId": lambda *_: "0x02", - # } - # ) - - -@pytest.fixture -def w3_base(): - return Web3(provider=DummyProvider(), middlewares=[]) + raise NotImplementedError(f"Cannot make request for {method}: {params}") @pytest.fixture -def w3_dummy(w3_base, result_generator_middleware): - w3_base.middleware_onion.add(result_generator_middleware) - return w3_base +def w3_dummy(request_mocker): + w3_base = Web3(provider=DummyProvider(), middlewares=[]) + with request_mocker( + w3_base, + mock_results={ + "eth_sendRawTransaction": lambda *args: args, + "net_version": lambda *_: 1, + "eth_chainId": lambda *_: "0x02", + }, + ): + yield w3_base def hex_to_bytes(s): @@ -421,35 +411,16 @@ def test_sign_and_send_raw_middleware_with_byte_addresses( # -- async -- # -@pytest_asyncio.fixture -async def async_result_generator_middleware(): - # TODO: replace with request mocker - return None - # return await async_construct_result_generator_middleware( - # { - # "eth_sendRawTransaction": lambda *args: args, - # "net_version": lambda *_: 1, - # "eth_chainId": lambda *_: "0x02", - # } - # ) - - class AsyncDummyProvider(AsyncBaseProvider): async def coro_request(self, method, params): raise NotImplementedError(f"Cannot make request for {method}:{params}") @pytest.fixture -def async_w3_base(): +def async_w3_dummy(): return AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) -@pytest.fixture -def async_w3_dummy(async_w3_base, async_result_generator_middleware): - async_w3_base.middleware_onion.add(async_result_generator_middleware) - return async_w3_base - - @pytest.fixture def async_w3(): return AsyncWeb3(AsyncEthereumTesterProvider()) diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py index 98b815b6c9..ac0b54664b 100644 --- a/web3/_utils/module_testing/utils.py +++ b/web3/_utils/module_testing/utils.py @@ -1,6 +1,7 @@ import copy from typing import ( Any, + Callable, Dict, TypeVar, Union, @@ -10,10 +11,6 @@ merge, ) -from web3.providers.eth_tester import ( - AsyncEthereumTesterProvider, - EthereumTesterProvider, -) from web3 import ( AsyncWeb3, Web3, @@ -62,8 +59,10 @@ def __init__( def __enter__(self): self.w3.provider.make_request = self._mock_request_handler + # reset request func cache to re-build request_func with mocked make_request self.w3.provider._request_func_cache = (None, None) + return self def __exit__(self, exc_type, exc_value, traceback): @@ -77,13 +76,16 @@ def _mock_request_handler(self, method, params): request_id = ( next(copy.deepcopy(self.w3.provider.request_counter)) - if not isinstance(self.w3.provider, EthereumTesterProvider) + if hasattr(self.w3.provider, "request_counter") else 1 ) response_dict = {"jsonrpc": "2.0", "id": request_id} if method in self.mock_results: - return merge(response_dict, {"result": self.mock_results[method]}) + mock_return = self.mock_results[method] + if isinstance(mock_return, Callable): + mock_return = mock_return(method, params) + return merge(response_dict, {"result": mock_return}) elif method in self.mock_errors: error = self.mock_errors[method] if not isinstance(error, dict): @@ -113,7 +115,7 @@ async def _async_mock_request_handler(self, method, params): request_id = ( next(copy.deepcopy(self.w3.provider.request_counter)) - if not isinstance(self.w3.provider, AsyncEthereumTesterProvider) + if hasattr(self.w3.provider, "request_counter") else 1 ) response_dict = {"jsonrpc": "2.0", "id": request_id} diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index fc29b46a72..ac794419cf 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -77,18 +77,16 @@ def combine_middlewares( response_processors = [ middleware.response_processor for middleware in reversed(middlewares) ] - return lambda method, params_or_response: functools.reduce( - lambda p_o_r, processor: processor(method, p_o_r), - response_processors, - provider_request_fn( - method, - functools.reduce( - lambda p, processor: processor(method, p), - request_processors, - params_or_response, - ), - ), - ) + + def request_fn(method: RPCEndpoint, params: Any) -> RPCResponse: + for processor in request_processors: + method, params = processor(method, params) + response = provider_request_fn(method, params) + for processor in response_processors: + method, response = processor(method, response) + return response + + return request_fn async def async_combine_middlewares( @@ -111,7 +109,7 @@ async def async_combine_middlewares( async def async_request_fn(method: RPCEndpoint, params: Any) -> RPCResponse: for processor in async_request_processors: - params = await processor(method, params) + method, params = await processor(method, params) response = await provider_request_fn(method, params) for processor in async_response_processors: response = await processor(method, response) diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index 1a1d4e195a..c85024c267 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -60,11 +60,12 @@ class AttributeDictMiddleware(Web3Middleware, ABC): def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: if "result" in response: - return assoc( - response, "result", AttributeDict.recursive(response["result"]) + return ( + method, + assoc(response, "result", AttributeDict.recursive(response["result"])), ) else: - return response + return (method, response) # -- async -- # @@ -77,9 +78,9 @@ async def async_response_processor( provider._request_processor.append_middleware_response_processor( response, _handle_async_response ) - return response + return (method, response) else: - return _handle_async_response(response) + return (method, _handle_async_response(response)) attrdict_middleware = AttributeDictMiddleware() diff --git a/web3/middleware/base.py b/web3/middleware/base.py index 2573199617..d869a0c074 100644 --- a/web3/middleware/base.py +++ b/web3/middleware/base.py @@ -1,4 +1,5 @@ from typing import ( + Sequence, TYPE_CHECKING, Any, TypeVar, @@ -27,12 +28,10 @@ class Web3Middleware: _w3: WEB3 def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: - return params + return method, params - def response_processor( - self, method: "RPCEndpoint", response: "RPCResponse" - ) -> "RPCResponse": - return response + def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: + return (method, response) # -- async -- # @@ -41,11 +40,11 @@ async def async_request_processor( method: "RPCEndpoint", params: Any, ) -> Any: - return params + return method, params async def async_response_processor( self, method: "RPCEndpoint", response: "RPCResponse", - ) -> "RPCResponse": - return response + ) -> Any: + return (method, response) diff --git a/web3/middleware/buffered_gas_estimate.py b/web3/middleware/buffered_gas_estimate.py index 0e79271cdd..436bff712a 100644 --- a/web3/middleware/buffered_gas_estimate.py +++ b/web3/middleware/buffered_gas_estimate.py @@ -42,7 +42,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: hex(get_buffered_gas_estimate(self._w3, transaction)), ) params = (transaction,) - return params + return method, params # -- async -- # @@ -55,7 +55,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A ) transaction = assoc(transaction, "gas", hex(gas_estimate)) params = (transaction,) - return params + return method, params buffered_gas_estimate_middleware = BufferedGasEstimateMiddleware() diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 1d2f50ef9e..0b0961b268 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -138,7 +138,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: formatter = self.request_formatters[method] params = formatter(params) - return params + return method, params def response_processor(self, method: RPCEndpoint, response: "RPCResponse") -> Any: if self.sync_formatters_builder is not None: @@ -149,11 +149,14 @@ def response_processor(self, method: RPCEndpoint, response: "RPCResponse") -> An self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] - return _apply_response_formatters( + return ( method, - self.result_formatters, - self.error_formatters, - response, + _apply_response_formatters( + method, + self.result_formatters, + self.error_formatters, + response, + ), ) # -- async -- # @@ -170,7 +173,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A formatter = self.request_formatters[method] params = formatter(params) - return params + return method, params async def async_response_processor( self, method: RPCEndpoint, response: "RPCResponse" diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index 48bf351a49..8255bd0a87 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -97,7 +97,7 @@ def request_processor(self, method: RPCEndpoint, params: Any) -> Any: ) params = (transaction,) - return params + return method, params # -- async -- # @@ -110,7 +110,7 @@ async def async_request_processor(self, method: RPCEndpoint, params: Any) -> Any transaction, latest_block, generated_gas_price ) params = (transaction,) - return params + return method, params gas_price_strategy_middleware = GasPriceStrategyMiddleware() diff --git a/web3/middleware/names.py b/web3/middleware/names.py index bc53614c91..630b402603 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -136,7 +136,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A abi_types_for_method, ) - return params + return method, params ens_name_to_address_middleware = EnsNameToAddressMiddleware() diff --git a/web3/middleware/signing.py b/web3/middleware/signing.py index a3dfc6474e..268214bd56 100644 --- a/web3/middleware/signing.py +++ b/web3/middleware/signing.py @@ -149,7 +149,7 @@ def __init__( def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method != "eth_sendTransaction": - return params + return method, params else: if self.format_and_fill_tx is None: self.format_and_fill_tx = compose( @@ -164,18 +164,21 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if tx_from is None or ( tx_from is not None and tx_from not in self._accounts ): - return params + return method, params else: account = self._accounts[to_checksum_address(tx_from)] raw_tx = account.sign_transaction(filled_transaction).rawTransaction - return (raw_tx.hex(),) + return ( + RPCEndpoint("eth_sendRawTransaction"), + [raw_tx.hex()], + ) # -- async -- # async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method != "eth_sendTransaction": - return params + return method, params else: formatted_transaction = format_transaction(params[0]) @@ -192,12 +195,15 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A if tx_from is None or ( tx_from is not None and tx_from not in self._accounts ): - return params + return method, params else: account = self._accounts[to_checksum_address(tx_from)] raw_tx = account.sign_transaction(filled_transaction).rawTransaction - return (raw_tx.hex(),) + return ( + RPCEndpoint("eth_sendRawTransaction"), + [raw_tx.hex()], + ) construct_sign_and_send_raw_middleware = SignAndSendRawMiddleware diff --git a/web3/middleware/stalecheck.py b/web3/middleware/stalecheck.py index 6e6a78225a..4a83850827 100644 --- a/web3/middleware/stalecheck.py +++ b/web3/middleware/stalecheck.py @@ -58,7 +58,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: else: raise StaleBlockchain(latest, self.allowable_delay) - return params + return method, params # -- async -- # @@ -71,7 +71,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A else: raise StaleBlockchain(latest, self.allowable_delay) - return params + return method, params make_stalecheck_middleware = StaleCheckMiddleware diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index 6beb3385d5..2882dc68e0 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -381,7 +381,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: fill_default_from, ) params = [filled_transaction] + list(params)[1:] - return params + return method, params # --- async --- # @@ -397,7 +397,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A ) params = [filled_transaction] + list(params)[1:] - return params + return method, params ethereum_tester_middleware = FormattingMiddleware( From aea95c14feedd075e2543d9ecafecaf878aa276b Mon Sep 17 00:00:00 2001 From: fselmo Date: Tue, 5 Dec 2023 15:51:28 -0700 Subject: [PATCH 09/24] Re-introduce the sandwiched middleware model via a method on the middleware class - Use a repr() for the middleware onion name when a non-hashable, such as the curried `build` method on a web3middlewarebuilder is used. --- .../test_name_to_address_middleware.py | 2 +- web3/datastructures.py | 13 +++++ web3/middleware/__init__.py | 28 ++++------- web3/middleware/abi.py | 4 +- web3/middleware/attrdict.py | 13 +++-- web3/middleware/base.py | 50 ++++++++++++++++++- web3/middleware/buffered_gas_estimate.py | 2 +- web3/middleware/formatting.py | 40 +++++++++------ web3/middleware/gas_price_strategy.py | 2 +- web3/middleware/names.py | 7 +-- .../normalize_request_parameters.py | 2 +- web3/middleware/proof_of_authority.py | 2 +- web3/middleware/pythonic.py | 2 +- web3/middleware/validation.py | 2 +- web3/providers/eth_tester/middleware.py | 4 +- 15 files changed, 116 insertions(+), 57 deletions(-) diff --git a/tests/core/middleware/test_name_to_address_middleware.py b/tests/core/middleware/test_name_to_address_middleware.py index 6434334f32..7eee54568d 100644 --- a/tests/core/middleware/test_name_to_address_middleware.py +++ b/tests/core/middleware/test_name_to_address_middleware.py @@ -126,7 +126,7 @@ async def async_ens_addr_account_balance(async_ens_mapped_address, _async_w3_set @pytest_asyncio.fixture async def async_w3(_async_w3_setup, async_ens_mapped_address): _async_w3_setup.ens = AsyncTempENS({NAME: async_ens_mapped_address}) - _async_w3_setup.middleware_onion.add(async_name_to_address_middleware) + _async_w3_setup.middleware_onion.add(ens_name_to_address_middleware) return _async_w3_setup diff --git a/web3/datastructures.py b/web3/datastructures.py index 0df838e0eb..a999da1db1 100644 --- a/web3/datastructures.py +++ b/web3/datastructures.py @@ -172,6 +172,12 @@ def add(self, element: TValue, name: Optional[TKey] = None) -> None: if name is None: name = cast(TKey, element) + try: + # handle unhashable types + name.__hash__() + except TypeError: + name = repr(name) + if name in self._queue: if name is element: raise ValueError("You can't add the same un-named instance twice") @@ -206,6 +212,13 @@ def inject( if layer == 0: if name is None: name = cast(TKey, element) + + try: + # handle unhashable types + name.__hash__() + except TypeError: + name = repr(name) + self._queue.move_to_end(name, last=False) elif layer == len(self._queue): return diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index ac794419cf..2beb12e80b 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -1,10 +1,10 @@ -import functools from typing import ( TYPE_CHECKING, Any, Callable, Coroutine, Sequence, + Type, ) from web3.types import ( @@ -63,7 +63,7 @@ def combine_middlewares( - middlewares: Sequence[Web3Middleware], + middlewares: Sequence[Type[Web3Middleware]], w3: "Web3", provider_request_fn: Callable[[RPCEndpoint, Any], Any], ) -> Callable[..., RPCResponse]: @@ -72,25 +72,15 @@ def combine_middlewares( and passes these args through the request processors, makes the request, and passes the response through the response processors. """ - [setattr(middleware, "_w3", w3) for middleware in middlewares] - request_processors = [middleware.request_processor for middleware in middlewares] - response_processors = [ - middleware.response_processor for middleware in reversed(middlewares) - ] - - def request_fn(method: RPCEndpoint, params: Any) -> RPCResponse: - for processor in request_processors: - method, params = processor(method, params) - response = provider_request_fn(method, params) - for processor in response_processors: - method, response = processor(method, response) - return response - - return request_fn + accumulator_fn = provider_request_fn + for middleware in reversed(middlewares): + # initialize the middleware and wrap the accumulator function down the stack + accumulator_fn = middleware(w3)._wrap_make_request(accumulator_fn) + return accumulator_fn async def async_combine_middlewares( - middlewares: Sequence[Web3Middleware], + middlewares: Sequence[Type[Web3Middleware]], async_w3: "AsyncWeb3", provider_request_fn: Callable[[RPCEndpoint, Any], Any], ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: @@ -99,7 +89,7 @@ async def async_combine_middlewares( and passes these args through the request processors, makes the request, and passes the response through the response processors. """ - [setattr(middleware, "_w3", async_w3) for middleware in middlewares] + middlewares = [middleware(async_w3) for middleware in middlewares] async_request_processors = [ middleware.async_request_processor for middleware in middlewares ] diff --git a/web3/middleware/abi.py b/web3/middleware/abi.py index d285893006..af8de03dfb 100644 --- a/web3/middleware/abi.py +++ b/web3/middleware/abi.py @@ -5,4 +5,6 @@ FormattingMiddleware, ) -abi_middleware = FormattingMiddleware(request_formatters=ABI_REQUEST_FORMATTERS) +abi_middleware = FormattingMiddleware.build_middleware( + request_formatters=ABI_REQUEST_FORMATTERS +) diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index c85024c267..931e2d06a1 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -60,12 +60,11 @@ class AttributeDictMiddleware(Web3Middleware, ABC): def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: if "result" in response: - return ( - method, - assoc(response, "result", AttributeDict.recursive(response["result"])), + return assoc( + response, "result", AttributeDict.recursive(response["result"]) ) else: - return (method, response) + return response # -- async -- # @@ -78,9 +77,9 @@ async def async_response_processor( provider._request_processor.append_middleware_response_processor( response, _handle_async_response ) - return (method, response) + return response else: - return (method, _handle_async_response(response)) + return _handle_async_response(response) -attrdict_middleware = AttributeDictMiddleware() +attrdict_middleware = AttributeDictMiddleware diff --git a/web3/middleware/base.py b/web3/middleware/base.py index d869a0c074..fa8879b6d8 100644 --- a/web3/middleware/base.py +++ b/web3/middleware/base.py @@ -1,8 +1,10 @@ +from abc import abstractmethod from typing import ( Sequence, TYPE_CHECKING, Any, TypeVar, + Union, ) if TYPE_CHECKING: @@ -27,11 +29,23 @@ class Web3Middleware: _w3: WEB3 + def __init__(self, w3: WEB3) -> None: + self._w3 = w3 + + def _wrap_make_request(self, make_request): + def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": + method, params = self.request_processor(method, params) + return self.response_processor(method, make_request(method, params)) + + return middleware + + # -- sync -- # + def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: return method, params def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> Any: - return (method, response) + return response # -- async -- # @@ -47,4 +61,36 @@ async def async_response_processor( method: "RPCEndpoint", response: "RPCResponse", ) -> Any: - return (method, response) + return response + + +class Web3MiddlewareBuilder(Web3Middleware): + @staticmethod + @abstractmethod + def build_middleware( + w3: Union["AsyncWeb3", "Web3"], + *args: Any, + **kwargs: Any, + ): + """ + Implementation should initialize the middleware class that implements it, + load it with any of the necessary properties that it needs for processing, + and curry for the ``w3`` argument since it isn't initially present when building + the middleware. + + example implementation: + + ```py + class MyMiddleware(Web3BuilderMiddleware): + + @staticmethod + @curry + def build_middleware(w3, request_formatters=None, response_formatters=None): + middleware = MyMiddleware(w3) + middleware.request_formatters = request_formatters + middleware.response_formatters = response_formatters + + return middleware + ``` + """ + raise NotImplementedError("Must be implemented by subclasses") diff --git a/web3/middleware/buffered_gas_estimate.py b/web3/middleware/buffered_gas_estimate.py index 436bff712a..11a3dede40 100644 --- a/web3/middleware/buffered_gas_estimate.py +++ b/web3/middleware/buffered_gas_estimate.py @@ -58,4 +58,4 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A return method, params -buffered_gas_estimate_middleware = BufferedGasEstimateMiddleware() +buffered_gas_estimate_middleware = BufferedGasEstimateMiddleware diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 0b0961b268..f22f4fcdad 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -4,6 +4,7 @@ Callable, Coroutine, Optional, + Union, cast, ) @@ -14,7 +15,7 @@ ) from web3.middleware.base import ( - Web3Middleware, + Web3MiddlewareBuilder, ) from web3.types import ( EthSubscriptionParams, @@ -91,9 +92,17 @@ def _format_response( ] -class FormattingMiddleware(Web3Middleware): - def __init__( - self, +class FormattingMiddleware(Web3MiddlewareBuilder): + request_formatters = None + result_formatters = None + error_formatters = None + sync_formatters_builder = None + async_formatters_builder = None + + @staticmethod + @curry + def build_middleware( + w3: Union["AsyncWeb3", "Web3"], # formatters option: request_formatters: Optional[Formatters] = None, result_formatters: Optional[Formatters] = None, @@ -120,11 +129,13 @@ def __init__( "Cannot specify formatters_builder and formatters at the same time" ) - self.request_formatters = request_formatters or {} - self.result_formatters = result_formatters or {} - self.error_formatters = error_formatters or {} - self.sync_formatters_builder = sync_formatters_builder - self.async_formatters_builder = async_formatters_builder + middleware = FormattingMiddleware(w3) + middleware.request_formatters = request_formatters or {} + middleware.result_formatters = result_formatters or {} + middleware.error_formatters = error_formatters or {} + middleware.sync_formatters_builder = sync_formatters_builder + middleware.async_formatters_builder = async_formatters_builder + return middleware def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if self.sync_formatters_builder is not None: @@ -149,14 +160,11 @@ def response_processor(self, method: RPCEndpoint, response: "RPCResponse") -> An self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] - return ( + return _apply_response_formatters( method, - _apply_response_formatters( - method, - self.result_formatters, - self.error_formatters, - response, - ), + self.result_formatters, + self.error_formatters, + response, ) # -- async -- # diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index 8255bd0a87..75e581d1b2 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -113,4 +113,4 @@ async def async_request_processor(self, method: RPCEndpoint, params: Any) -> Any return method, params -gas_price_strategy_middleware = GasPriceStrategyMiddleware() +gas_price_strategy_middleware = GasPriceStrategyMiddleware diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 630b402603..75239d0a12 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -105,10 +105,11 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: normalizers = [ abi_ens_resolver(self._w3), ] - self._formatting_middleware = FormattingMiddleware( + self._formatting_middleware = FormattingMiddleware.build_middleware( request_formatters=abi_request_formatters(normalizers, RPC_ABIS) # type: ignore # noqa: E501 ) - return self._formatting_middleware.request_processor(method, params) + + return self._formatting_middleware(self._w3).request_processor(method, params) # -- async -- # @@ -139,4 +140,4 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A return method, params -ens_name_to_address_middleware = EnsNameToAddressMiddleware() +ens_name_to_address_middleware = EnsNameToAddressMiddleware diff --git a/web3/middleware/normalize_request_parameters.py b/web3/middleware/normalize_request_parameters.py index 66c8618df1..2f4bf0d983 100644 --- a/web3/middleware/normalize_request_parameters.py +++ b/web3/middleware/normalize_request_parameters.py @@ -6,6 +6,6 @@ FormattingMiddleware, ) -request_parameter_normalizer = FormattingMiddleware( +request_parameter_normalizer = FormattingMiddleware.build_middleware( request_formatters=METHOD_NORMALIZERS, ) diff --git a/web3/middleware/proof_of_authority.py b/web3/middleware/proof_of_authority.py index d162b295d4..6e81360c18 100644 --- a/web3/middleware/proof_of_authority.py +++ b/web3/middleware/proof_of_authority.py @@ -49,7 +49,7 @@ geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) -extradata_to_poa_middleware = FormattingMiddleware( +extradata_to_poa_middleware = FormattingMiddleware.build_middleware( result_formatters={ RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), diff --git a/web3/middleware/pythonic.py b/web3/middleware/pythonic.py index f83ab5568a..59b5f8a958 100644 --- a/web3/middleware/pythonic.py +++ b/web3/middleware/pythonic.py @@ -6,7 +6,7 @@ FormattingMiddleware, ) -pythonic_middleware = FormattingMiddleware( +pythonic_middleware = FormattingMiddleware.build_middleware( request_formatters=PYTHONIC_REQUEST_FORMATTERS, result_formatters=PYTHONIC_RESULT_FORMATTERS, ) diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index cbd08067db..5551ac0a3d 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -160,7 +160,7 @@ async def async_build_method_validators( return _build_formatters_dict(request_formatters) -validation_middleware = FormattingMiddleware( +validation_middleware = FormattingMiddleware.build_middleware( sync_formatters_builder=build_method_validators, async_formatters_builder=async_build_method_validators, ) diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index 2882dc68e0..1fa133e889 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -400,7 +400,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A return method, params -ethereum_tester_middleware = FormattingMiddleware( +ethereum_tester_middleware = FormattingMiddleware.build_middleware( request_formatters=request_formatters, result_formatters=result_formatters ) -default_transaction_fields_middleware = DefaultTransactionFieldsMiddleware() +default_transaction_fields_middleware = DefaultTransactionFieldsMiddleware From 6338ddb50e0109902e777f033ef4f7e5702a5858 Mon Sep 17 00:00:00 2001 From: fselmo Date: Wed, 6 Dec 2023 18:14:23 -0700 Subject: [PATCH 10/24] Build on the wrapped make_request middleware refactor - Add support back for ``local_filter_middleware`` by overriding the ``_wrap_make_request()`` of the base middleware class. If this method is overridden, the ``make_request`` can be effectively replaced if a request is not needed. --- ens/utils.py | 10 +- tests/core/filtering/conftest.py | 2 +- tests/core/filtering/utils.py | 5 +- .../core/middleware/test_filter_middleware.py | 1 - web3/middleware/__init__.py | 28 +-- web3/middleware/abi.py | 4 +- web3/middleware/base.py | 36 +++- web3/middleware/filter.py | 174 +++++++++--------- web3/middleware/formatting.py | 16 +- web3/middleware/names.py | 4 +- .../normalize_request_parameters.py | 4 +- web3/middleware/proof_of_authority.py | 4 +- web3/middleware/pythonic.py | 4 +- web3/middleware/signing.py | 18 +- web3/middleware/stalecheck.py | 28 ++- web3/middleware/validation.py | 4 +- web3/providers/eth_tester/middleware.py | 4 +- 17 files changed, 190 insertions(+), 156 deletions(-) diff --git a/ens/utils.py b/ens/utils.py index 7b7aa33dcf..9b1bd27052 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -100,14 +100,16 @@ def init_web3( def customize_web3(w3: "_Web3") -> "_Web3": from web3.middleware import ( - StaleCheckMiddleware, + StaleCheckMiddlewareBuilder, ) if w3.middleware_onion.get("name_to_address"): w3.middleware_onion.remove("name_to_address") if not w3.middleware_onion.get("stalecheck"): - stalecheck_middleware = StaleCheckMiddleware(ACCEPTABLE_STALE_HOURS * 3600) + stalecheck_middleware = StaleCheckMiddlewareBuilder( + ACCEPTABLE_STALE_HOURS * 3600 + ) w3.middleware_onion.add(stalecheck_middleware, name="stalecheck") return w3 @@ -306,7 +308,7 @@ def init_async_web3( AsyncEth as AsyncEthMain, ) from web3.middleware import ( - StaleCheckMiddleware, + StaleCheckMiddlewareBuilder, ) middlewares = list(middlewares) @@ -316,7 +318,7 @@ def init_async_web3( if "stalecheck" not in (name for mw, name in middlewares): middlewares.append( - (StaleCheckMiddleware(ACCEPTABLE_STALE_HOURS * 3600), "stalecheck") + (StaleCheckMiddlewareBuilder(ACCEPTABLE_STALE_HOURS * 3600), "stalecheck") ) if provider is default: diff --git a/tests/core/filtering/conftest.py b/tests/core/filtering/conftest.py index 331381be2a..25a7a1bec5 100644 --- a/tests/core/filtering/conftest.py +++ b/tests/core/filtering/conftest.py @@ -67,7 +67,7 @@ def create_filter(request): @pytest.fixture( scope="function", params=[True, False], - ids=["async_local_filter_middleware", "node_based_filter"], + ids=["local_filter_middleware", "node_based_filter"], ) def async_w3(request): return _async_w3_fixture_logic(request) diff --git a/tests/core/filtering/utils.py b/tests/core/filtering/utils.py index 714505966a..b3e27036ef 100644 --- a/tests/core/filtering/utils.py +++ b/tests/core/filtering/utils.py @@ -6,7 +6,6 @@ AsyncEth, ) from web3.middleware import ( - async_local_filter_middleware, local_filter_middleware, ) from web3.providers.eth_tester import ( @@ -49,10 +48,10 @@ def _emitter_fixture_logic( def _async_w3_fixture_logic(request): use_filter_middleware = request.param provider = AsyncEthereumTesterProvider() - async_w3 = AsyncWeb3(provider, modules={"eth": [AsyncEth]}, middlewares=[]) + async_w3 = AsyncWeb3(provider) if use_filter_middleware: - async_w3.middleware_onion.add(async_local_filter_middleware) + async_w3.middleware_onion.add(local_filter_middleware) return async_w3 diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index 7ade451297..94bc3ab5ed 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -15,7 +15,6 @@ AsyncEth, ) from web3.middleware import ( - async_local_filter_middleware, attrdict_middleware, local_filter_middleware, ) diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 2beb12e80b..c8d9373d20 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -34,7 +34,6 @@ ens_name_to_address_middleware, ) from .filter import ( - async_local_filter_middleware, local_filter_middleware, ) from .gas_price_strategy import ( @@ -48,10 +47,10 @@ ) from .signing import ( construct_sign_and_send_raw_middleware, - SignAndSendRawMiddleware, + SignAndSendRawMiddlewareBuilder, ) from .stalecheck import ( - StaleCheckMiddleware, + StaleCheckMiddlewareBuilder, make_stalecheck_middleware, ) from .validation import ( @@ -89,20 +88,9 @@ async def async_combine_middlewares( and passes these args through the request processors, makes the request, and passes the response through the response processors. """ - middlewares = [middleware(async_w3) for middleware in middlewares] - async_request_processors = [ - middleware.async_request_processor for middleware in middlewares - ] - async_response_processors = [ - middleware.async_response_processor for middleware in reversed(middlewares) - ] - - async def async_request_fn(method: RPCEndpoint, params: Any) -> RPCResponse: - for processor in async_request_processors: - method, params = await processor(method, params) - response = await provider_request_fn(method, params) - for processor in async_response_processors: - response = await processor(method, response) - return response - - return async_request_fn + accumulator_fn = provider_request_fn + for middleware in reversed(middlewares): + # initialize the middleware and wrap the accumulator function down the stack + initialized = middleware(async_w3) + accumulator_fn = await initialized._async_wrap_make_request(accumulator_fn) + return accumulator_fn diff --git a/web3/middleware/abi.py b/web3/middleware/abi.py index af8de03dfb..64b786ae75 100644 --- a/web3/middleware/abi.py +++ b/web3/middleware/abi.py @@ -2,9 +2,9 @@ ABI_REQUEST_FORMATTERS, ) from web3.middleware.formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) -abi_middleware = FormattingMiddleware.build_middleware( +abi_middleware = FormattingMiddlewareBuilder.build( request_formatters=ABI_REQUEST_FORMATTERS ) diff --git a/web3/middleware/base.py b/web3/middleware/base.py index fa8879b6d8..fbe12e58b6 100644 --- a/web3/middleware/base.py +++ b/web3/middleware/base.py @@ -1,3 +1,4 @@ +import warnings from abc import abstractmethod from typing import ( Sequence, @@ -32,6 +33,8 @@ class Web3Middleware: def __init__(self, w3: WEB3) -> None: self._w3 = w3 + # -- sync -- # + def _wrap_make_request(self, make_request): def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": method, params = self.request_processor(method, params) @@ -39,8 +42,6 @@ def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": return middleware - # -- sync -- # - def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: return method, params @@ -49,6 +50,16 @@ def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> # -- async -- # + async def _async_wrap_make_request(self, make_request): + async def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": + method, params = await self.async_request_processor(method, params) + return await self.async_response_processor( + method, + await make_request(method, params), + ) + + return middleware + async def async_request_processor( self, method: "RPCEndpoint", @@ -67,7 +78,7 @@ async def async_response_processor( class Web3MiddlewareBuilder(Web3Middleware): @staticmethod @abstractmethod - def build_middleware( + def build( w3: Union["AsyncWeb3", "Web3"], *args: Any, **kwargs: Any, @@ -82,15 +93,26 @@ def build_middleware( ```py class MyMiddleware(Web3BuilderMiddleware): + internal_property: str = None @staticmethod @curry - def build_middleware(w3, request_formatters=None, response_formatters=None): + def builder(user_provided_argument, w3): middleware = MyMiddleware(w3) - middleware.request_formatters = request_formatters - middleware.response_formatters = response_formatters - + middleware.internal_property = user_provided_argument return middleware + + def request_processor(self, method, params): + ... + + def response_processor(self, method, response): + ... + + construct_my_middleware = MyMiddleware.builder + + w3 = Web3(provider) + my_middleware = construct_my_middleware("my argument") + w3.middleware_onion.inject(my_middleware, layer=0) ``` """ raise NotImplementedError("Must be implemented by subclasses") diff --git a/web3/middleware/filter.py b/web3/middleware/filter.py index 1350a79f8c..8e30d4027d 100644 --- a/web3/middleware/filter.py +++ b/web3/middleware/filter.py @@ -5,7 +5,6 @@ Any, AsyncIterable, AsyncIterator, - Callable, Dict, Generator, Iterable, @@ -43,6 +42,9 @@ from web3._utils.rpc_abi import ( RPC, ) +from web3.middleware.base import ( + Web3Middleware, +) from web3.types import ( Coroutine, FilterParams, @@ -338,52 +340,6 @@ def block_hashes_in_range( yield getattr(w3.eth.get_block(BlockNumber(block_number)), "hash", None) -def local_filter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], RPCResponse]: - filters = {} - filter_id_counter = map(to_hex, itertools.count()) - - def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in NEW_FILTER_METHODS: - filter_id = next(filter_id_counter) - - _filter: Union[RequestLogs, RequestBlocks] - if method == RPC.eth_newFilter: - _filter = RequestLogs( - w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) - ) - - elif method == RPC.eth_newBlockFilter: - _filter = RequestBlocks(w3) - - else: - raise NotImplementedError(method) - - filters[filter_id] = _filter - return {"result": filter_id} - - elif method in FILTER_CHANGES_METHODS: - filter_id = params[0] - # Pass through to filters not created by middleware - if filter_id not in filters: - return make_request(method, params) - _filter = filters[filter_id] - if method == RPC.eth_getFilterChanges: - return {"result": next(_filter.filter_changes)} - - elif method == RPC.eth_getFilterLogs: - # type ignored b/c logic prevents RequestBlocks which - # doesn't implement get_logs - return {"result": _filter.get_logs()} # type: ignore - else: - raise NotImplementedError(method) - else: - return make_request(method, params) - - return middleware - - # --- async --- # @@ -611,50 +567,102 @@ async def async_block_hashes_in_range( return block_hashes -async def async_local_filter_middleware( - make_request: Callable[[RPCEndpoint, Any], Any], w3: "Web3" -) -> Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]]: - filters = {} - filter_id_counter = map(to_hex, itertools.count()) +# -- middleware -- # - async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse: - if method in NEW_FILTER_METHODS: - filter_id = next(filter_id_counter) - _filter: Union[AsyncRequestLogs, AsyncRequestBlocks] - if method == RPC.eth_newFilter: - _filter = await AsyncRequestLogs( - w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) - ) +class LocalFilterMiddleware(Web3Middleware): + def __init__(self, w3): + self.filters = {} + self.filter_id_counter = itertools.count() + super().__init__(w3) + + def _wrap_make_request(self, make_request): + def middleware(method, params): + if method in NEW_FILTER_METHODS: + filter_id = to_hex(next(self.filter_id_counter)) + + _filter: Union[RequestLogs, RequestBlocks] + + if method == RPC.eth_newFilter: + _filter = RequestLogs( + self._w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + ) - elif method == RPC.eth_newBlockFilter: - _filter = await AsyncRequestBlocks(w3) + elif method == RPC.eth_newBlockFilter: + _filter = RequestBlocks(self._w3) + else: + raise NotImplementedError(method) + + self.filters[filter_id] = _filter + return {"result": filter_id} + + elif method in FILTER_CHANGES_METHODS: + _filter_id = params[0] + + # Pass through to filters not created by middleware + if _filter_id not in self.filters: + return make_request(method, params) + + _filter = self.filters[_filter_id] + if method == RPC.eth_getFilterChanges: + return {"result": next(_filter.filter_changes)} + + elif method == RPC.eth_getFilterLogs: + # type ignored b/c logic prevents RequestBlocks which + # doesn't implement get_logs + return {"result": _filter.get_logs()} # type: ignore + else: + raise NotImplementedError(method) else: - raise NotImplementedError(method) + return make_request(method, params) - filters[filter_id] = _filter - return {"result": filter_id} + return middleware - elif method in FILTER_CHANGES_METHODS: - filter_id = params[0] - # Pass through to filters not created by middleware - if filter_id not in filters: - return await make_request(method, params) - _filter = filters[filter_id] + # -- async -- # + + async def _async_wrap_make_request(self, make_request): + async def middleware(method, params): + if method in NEW_FILTER_METHODS: + filter_id = to_hex(next(self.filter_id_counter)) + + _filter: Union[AsyncRequestLogs, AsyncRequestBlocks] + + if method == RPC.eth_newFilter: + _filter = await AsyncRequestLogs( + self._w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + ) + + elif method == RPC.eth_newBlockFilter: + _filter = await AsyncRequestBlocks(self._w3) + + else: + raise NotImplementedError(method) + + self.filters[filter_id] = _filter + return {"result": filter_id} + + elif method in FILTER_CHANGES_METHODS: + _filter_id = params[0] - if method == RPC.eth_getFilterChanges: - changes = await _filter.filter_changes.__anext__() - return {"result": changes} + # Pass through to filters not created by middleware + if _filter_id not in self.filters: + return await make_request(method, params) - elif method == RPC.eth_getFilterLogs: - # type ignored b/c logic prevents RequestBlocks which - # doesn't implement get_logs - logs = await _filter.get_logs() # type: ignore - return {"result": logs} + _filter = self.filters[_filter_id] + if method == RPC.eth_getFilterChanges: + return {"result": await _filter.filter_changes.__anext__()} + + elif method == RPC.eth_getFilterLogs: + # type ignored b/c logic prevents RequestBlocks which + # doesn't implement get_logs + return {"result": await _filter.get_logs()} + else: + raise NotImplementedError(method) else: - raise NotImplementedError(method) - else: - return await make_request(method, params) + return await make_request(method, params) + + return middleware + - return middleware +local_filter_middleware = LocalFilterMiddleware diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index f22f4fcdad..33987e1f78 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -92,16 +92,16 @@ def _format_response( ] -class FormattingMiddleware(Web3MiddlewareBuilder): - request_formatters = None - result_formatters = None - error_formatters = None - sync_formatters_builder = None - async_formatters_builder = None +class FormattingMiddlewareBuilder(Web3MiddlewareBuilder): + request_formatters: Formatters = None + result_formatters: Formatters = None + error_formatters: Formatters = None + sync_formatters_builder: SYNC_FORMATTERS_BUILDER = None + async_formatters_builder: ASYNC_FORMATTERS_BUILDER = None @staticmethod @curry - def build_middleware( + def build( w3: Union["AsyncWeb3", "Web3"], # formatters option: request_formatters: Optional[Formatters] = None, @@ -129,7 +129,7 @@ def build_middleware( "Cannot specify formatters_builder and formatters at the same time" ) - middleware = FormattingMiddleware(w3) + middleware = FormattingMiddlewareBuilder(w3) middleware.request_formatters = request_formatters or {} middleware.result_formatters = result_formatters or {} middleware.error_formatters = error_formatters or {} diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 75239d0a12..f9e49998c3 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -34,7 +34,7 @@ Web3Middleware, ) from .formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) if TYPE_CHECKING: @@ -105,7 +105,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: normalizers = [ abi_ens_resolver(self._w3), ] - self._formatting_middleware = FormattingMiddleware.build_middleware( + self._formatting_middleware = FormattingMiddlewareBuilder.build( request_formatters=abi_request_formatters(normalizers, RPC_ABIS) # type: ignore # noqa: E501 ) diff --git a/web3/middleware/normalize_request_parameters.py b/web3/middleware/normalize_request_parameters.py index 2f4bf0d983..a18041fbd9 100644 --- a/web3/middleware/normalize_request_parameters.py +++ b/web3/middleware/normalize_request_parameters.py @@ -3,9 +3,9 @@ ) from .formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) -request_parameter_normalizer = FormattingMiddleware.build_middleware( +request_parameter_normalizer = FormattingMiddlewareBuilder.build( request_formatters=METHOD_NORMALIZERS, ) diff --git a/web3/middleware/proof_of_authority.py b/web3/middleware/proof_of_authority.py index 6e81360c18..bfb4623591 100644 --- a/web3/middleware/proof_of_authority.py +++ b/web3/middleware/proof_of_authority.py @@ -23,7 +23,7 @@ RPC, ) from web3.middleware.formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) if TYPE_CHECKING: @@ -49,7 +49,7 @@ geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) -extradata_to_poa_middleware = FormattingMiddleware.build_middleware( +extradata_to_poa_middleware = FormattingMiddlewareBuilder.build( result_formatters={ RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), diff --git a/web3/middleware/pythonic.py b/web3/middleware/pythonic.py index 59b5f8a958..cb8856595b 100644 --- a/web3/middleware/pythonic.py +++ b/web3/middleware/pythonic.py @@ -3,10 +3,10 @@ PYTHONIC_RESULT_FORMATTERS, ) from web3.middleware.formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) -pythonic_middleware = FormattingMiddleware.build_middleware( +pythonic_middleware = FormattingMiddlewareBuilder.build( request_formatters=PYTHONIC_REQUEST_FORMATTERS, result_formatters=PYTHONIC_RESULT_FORMATTERS, ) diff --git a/web3/middleware/signing.py b/web3/middleware/signing.py index 268214bd56..9cd7c223c0 100644 --- a/web3/middleware/signing.py +++ b/web3/middleware/signing.py @@ -12,6 +12,8 @@ Union, ) +from toolz import curry + from eth_account import ( Account, ) @@ -53,6 +55,7 @@ ) from web3.middleware.base import ( Web3Middleware, + Web3MiddlewareBuilder, ) from web3.types import ( RPCEndpoint, @@ -139,13 +142,16 @@ def format_transaction(transaction: TxParams) -> TxParams: ) -class SignAndSendRawMiddleware(Web3Middleware): +class SignAndSendRawMiddlewareBuilder(Web3MiddlewareBuilder): + _accounts = None format_and_fill_tx = None - def __init__( - self, private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]] - ): - self._accounts = gen_normalized_accounts(private_key_or_account) + @staticmethod + @curry + def build(private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]], w3): + middleware = SignAndSendRawMiddlewareBuilder(w3) + middleware._accounts = gen_normalized_accounts(private_key_or_account) + return middleware def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method != "eth_sendTransaction": @@ -206,4 +212,4 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A ) -construct_sign_and_send_raw_middleware = SignAndSendRawMiddleware +construct_sign_and_send_raw_middleware = SignAndSendRawMiddlewareBuilder.build diff --git a/web3/middleware/stalecheck.py b/web3/middleware/stalecheck.py index 4a83850827..4f8b98391d 100644 --- a/web3/middleware/stalecheck.py +++ b/web3/middleware/stalecheck.py @@ -8,11 +8,14 @@ Optional, ) +from toolz import curry + from web3.exceptions import ( StaleBlockchain, ) from web3.middleware.base import ( Web3Middleware, + Web3MiddlewareBuilder, ) from web3.types import ( BlockData, @@ -34,20 +37,27 @@ def _is_fresh(block: BlockData, allowable_delay: int) -> bool: return False -class StaleCheckMiddleware(Web3Middleware): - def __init__( - self, +class StaleCheckMiddlewareBuilder(Web3MiddlewareBuilder): + allowable_delay: int + skip_stalecheck_for_methods: Collection[str] + cache: Dict[str, Optional[BlockData]] + + @staticmethod + @curry + def build( allowable_delay: int, + w3, skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, - ) -> None: + ) -> Web3Middleware: if allowable_delay <= 0: raise ValueError( "You must set a positive allowable_delay in seconds for this middleware" ) - - self.allowable_delay = allowable_delay - self.skip_stalecheck_for_methods = skip_stalecheck_for_methods - self.cache: Dict[str, Optional[BlockData]] = {"latest": None} + middleware = StaleCheckMiddlewareBuilder(w3) + middleware.allowable_delay = allowable_delay + middleware.skip_stalecheck_for_methods = skip_stalecheck_for_methods + middleware.cache = {"latest": None} + return middleware def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method not in self.skip_stalecheck_for_methods: @@ -74,4 +84,4 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A return method, params -make_stalecheck_middleware = StaleCheckMiddleware +make_stalecheck_middleware = StaleCheckMiddlewareBuilder.build diff --git a/web3/middleware/validation.py b/web3/middleware/validation.py index 5551ac0a3d..a8660e5e8f 100644 --- a/web3/middleware/validation.py +++ b/web3/middleware/validation.py @@ -33,7 +33,7 @@ Web3ValidationError, ) from web3.middleware.formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) from web3.types import ( Formatters, @@ -160,7 +160,7 @@ async def async_build_method_validators( return _build_formatters_dict(request_formatters) -validation_middleware = FormattingMiddleware.build_middleware( +validation_middleware = FormattingMiddlewareBuilder.build( sync_formatters_builder=build_method_validators, async_formatters_builder=async_build_method_validators, ) diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index 1fa133e889..453a62d081 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -44,7 +44,7 @@ Web3Middleware, ) from web3.middleware.formatting import ( - FormattingMiddleware, + FormattingMiddlewareBuilder, ) from web3.types import ( RPCEndpoint, @@ -400,7 +400,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A return method, params -ethereum_tester_middleware = FormattingMiddleware.build_middleware( +ethereum_tester_middleware = FormattingMiddlewareBuilder.build( request_formatters=request_formatters, result_formatters=result_formatters ) default_transaction_fields_middleware = DefaultTransactionFieldsMiddleware From 806eb73cd5e8d4e323d3d49ba98c5785605cc701 Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 8 Dec 2023 14:53:31 -0700 Subject: [PATCH 11/24] Fix remaining tests that were broken due to refactor --- ens/utils.py | 8 +- tests/core/contracts/test_contract_example.py | 8 +- tests/core/eth-module/test_transactions.py | 28 +++-- .../test_time_based_gas_price_strategy.py | 67 ++++++------ tests/core/manager/conftest.py | 22 ++-- .../core/manager/test_default_middlewares.py | 22 +--- .../test_middleware_can_be_stateful.py | 18 ++-- .../middleware/test_eth_tester_middleware.py | 10 +- .../core/middleware/test_filter_middleware.py | 69 +++++------- .../middleware/test_formatting_middleware.py | 101 ++++++++---------- .../middleware/test_gas_price_strategy.py | 99 +++++++---------- tests/core/middleware/test_stalecheck.py | 92 ++++++++++------ .../middleware/test_transaction_signing.py | 15 ++- tests/core/providers/test_http_provider.py | 8 +- tests/ens/test_ens.py | 4 +- tests/ens/test_utils.py | 8 +- web3/_utils/module_testing/utils.py | 10 +- web3/middleware/__init__.py | 3 + web3/middleware/formatting.py | 3 + 19 files changed, 294 insertions(+), 301 deletions(-) diff --git a/ens/utils.py b/ens/utils.py index 9b1bd27052..8e99c81018 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -100,14 +100,14 @@ def init_web3( def customize_web3(w3: "_Web3") -> "_Web3": from web3.middleware import ( - StaleCheckMiddlewareBuilder, + make_stalecheck_middleware, ) if w3.middleware_onion.get("name_to_address"): w3.middleware_onion.remove("name_to_address") if not w3.middleware_onion.get("stalecheck"): - stalecheck_middleware = StaleCheckMiddlewareBuilder( + stalecheck_middleware = make_stalecheck_middleware( ACCEPTABLE_STALE_HOURS * 3600 ) w3.middleware_onion.add(stalecheck_middleware, name="stalecheck") @@ -308,7 +308,7 @@ def init_async_web3( AsyncEth as AsyncEthMain, ) from web3.middleware import ( - StaleCheckMiddlewareBuilder, + make_stalecheck_middleware, ) middlewares = list(middlewares) @@ -318,7 +318,7 @@ def init_async_web3( if "stalecheck" not in (name for mw, name in middlewares): middlewares.append( - (StaleCheckMiddlewareBuilder(ACCEPTABLE_STALE_HOURS * 3600), "stalecheck") + (make_stalecheck_middleware(ACCEPTABLE_STALE_HOURS * 3600), "stalecheck") ) if provider is default: diff --git a/tests/core/contracts/test_contract_example.py b/tests/core/contracts/test_contract_example.py index 7ed27390e9..28a52cba80 100644 --- a/tests/core/contracts/test_contract_example.py +++ b/tests/core/contracts/test_contract_example.py @@ -5,6 +5,7 @@ import pytest_asyncio from web3 import ( + AsyncWeb3, EthereumTesterProvider, Web3, ) @@ -121,10 +122,9 @@ def async_eth_tester(): @pytest_asyncio.fixture() async def async_w3(): - provider = AsyncEthereumTesterProvider() - w3 = Web3(provider, modules={"eth": [AsyncEth]}, middlewares=provider.middlewares) - w3.eth.default_account = await w3.eth.coinbase - return w3 + async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) + async_w3.eth.default_account = await async_w3.eth.coinbase + return async_w3 @pytest_asyncio.fixture() diff --git a/tests/core/eth-module/test_transactions.py b/tests/core/eth-module/test_transactions.py index 1fa2b78e16..dbb668c40a 100644 --- a/tests/core/eth-module/test_transactions.py +++ b/tests/core/eth-module/test_transactions.py @@ -1,3 +1,6 @@ +import collections +import itertools + import pytest from eth_utils import ( @@ -170,8 +173,9 @@ def test_passing_string_to_to_hex(w3): w3.eth.wait_for_transaction_receipt(transaction_hash, timeout=RECEIPT_TIMEOUT) -def test_unmined_transaction_wait_for_receipt(w3): - # w3.middleware_onion.add(unmined_receipt_simulator_middleware) +def test_unmined_transaction_wait_for_receipt(w3, request_mocker): + receipt_counters = collections.defaultdict(itertools.count) + txn_hash = w3.eth.send_transaction( { "from": w3.eth.coinbase, @@ -179,12 +183,22 @@ def test_unmined_transaction_wait_for_receipt(w3): "value": 123457, } ) - with pytest.raises(TransactionNotFound): - w3.eth.get_transaction_receipt(txn_hash) + unmocked_make_request = w3.provider.make_request + + with request_mocker( + w3, + mock_results={ + RPC.eth_getTransactionReceipt: lambda method, params: None + if next(receipt_counters[params[0]]) < 5 + else unmocked_make_request(method, params)["result"] + }, + ): + with pytest.raises(TransactionNotFound): + w3.eth.get_transaction_receipt(txn_hash) - txn_receipt = w3.eth.wait_for_transaction_receipt(txn_hash) - assert txn_receipt["transactionHash"] == txn_hash - assert txn_receipt["blockHash"] is not None + txn_receipt = w3.eth.wait_for_transaction_receipt(txn_hash) + assert txn_receipt["transactionHash"] == txn_hash + assert txn_receipt["blockHash"] is not None def test_get_transaction_formatters(w3, request_mocker): diff --git a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py index 20ac7b5a07..c94e4c1fdb 100644 --- a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py +++ b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py @@ -147,25 +147,22 @@ def _get_block_by_something(method, params): (dict(max_wait_seconds=80, sample_size=5, probability=50, weighted=True), 11), ), ) -def test_time_based_gas_price_strategy(strategy_params, expected): - # fixture_middleware = construct_result_generator_middleware( - # { - # "eth_getBlockByHash": _get_block_by_something, - # "eth_getBlockByNumber": _get_block_by_something, - # } - # ) - - w3 = Web3( - provider=BaseProvider(), - # middlewares=[fixture_middleware], - ) +def test_time_based_gas_price_strategy(strategy_params, expected, request_mocker): + w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( **strategy_params, ) w3.eth.set_gas_price_strategy(time_based_gas_price_strategy) - actual = w3.eth.generate_gas_price() - assert actual == expected + with request_mocker( + w3, + mock_results={ + "eth_getBlockByHash": _get_block_by_something, + "eth_getBlockByNumber": _get_block_by_something, + }, + ): + actual = w3.eth.generate_gas_price() + assert actual == expected def _get_initial_block(method, params): @@ -183,7 +180,7 @@ def _get_gas_price(method, params): return 4321 -def test_time_based_gas_price_strategy_without_transactions(): +def test_time_based_gas_price_strategy_without_transactions(request_mocker): # fixture_middleware = construct_result_generator_middleware( # { # "eth_getBlockByHash": _get_initial_block, @@ -192,10 +189,7 @@ def test_time_based_gas_price_strategy_without_transactions(): # } # ) - w3 = Web3( - provider=BaseProvider(), - # middlewares=[fixture_middleware], - ) + w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( max_wait_seconds=80, @@ -204,8 +198,16 @@ def test_time_based_gas_price_strategy_without_transactions(): weighted=True, ) w3.eth.set_gas_price_strategy(time_based_gas_price_strategy) - actual = w3.eth.generate_gas_price() - assert actual == w3.eth.gas_price + with request_mocker( + w3, + mock_results={ + "eth_getBlockByHash": _get_initial_block, + "eth_getBlockByNumber": _get_initial_block, + "eth_gasPrice": _get_gas_price, + }, + ): + actual = w3.eth.generate_gas_price() + assert actual == w3.eth.gas_price @pytest.mark.parametrize( @@ -266,23 +268,20 @@ def test_time_based_gas_price_strategy_without_transactions(): ), ) def test_time_based_gas_price_strategy_zero_sample( - strategy_params_zero, expected_exception_message + strategy_params_zero, expected_exception_message, request_mocker ): with pytest.raises(Web3ValidationError) as excinfo: - # fixture_middleware = construct_result_generator_middleware( - # { - # "eth_getBlockByHash": _get_block_by_something, - # "eth_getBlockByNumber": _get_block_by_something, - # } - # ) - - w3 = Web3( - provider=BaseProvider(), - # middlewares=[fixture_middleware], - ) + w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy_zero = construct_time_based_gas_price_strategy( **strategy_params_zero, ) w3.eth.set_gas_price_strategy(time_based_gas_price_strategy_zero) - w3.eth.generate_gas_price() + with request_mocker( + w3, + mock_results={ + "eth_getBlockByHash": _get_block_by_something, + "eth_getBlockByNumber": _get_block_by_something, + }, + ): + w3.eth.generate_gas_price() assert str(excinfo.value) == expected_exception_message diff --git a/tests/core/manager/conftest.py b/tests/core/manager/conftest.py index 458204dcd5..d522e3d438 100644 --- a/tests/core/manager/conftest.py +++ b/tests/core/manager/conftest.py @@ -1,6 +1,8 @@ import itertools import pytest +from web3.middleware import Web3Middleware + @pytest.fixture def middleware_factory(): @@ -15,15 +17,19 @@ class Wrapper: def __repr__(self): return "middleware-" + key - def __call__(self, make_request, w3): - def middleware_fn(method, params): - params.append(key) - method = "|".join((method, key)) - response = make_request(method, params) - response["result"]["middlewares"].append(key) - return response + def __call__(self, web3): + class Middleware(Web3Middleware): + def _wrap_make_request(self, make_request): + def middleware_fn(method, params): + params.append(key) + method = "|".join((method, key)) + response = make_request(method, params) + response["result"]["middlewares"].append(key) + return response + + return middleware_fn - return middleware_fn + return Middleware(web3) return Wrapper() diff --git a/tests/core/manager/test_default_middlewares.py b/tests/core/manager/test_default_middlewares.py index bda6c1d3e8..b55815058c 100644 --- a/tests/core/manager/test_default_middlewares.py +++ b/tests/core/manager/test_default_middlewares.py @@ -2,7 +2,6 @@ RequestManager, ) from web3.middleware import ( - abi_middleware, attrdict_middleware, buffered_gas_estimate_middleware, ens_name_to_address_middleware, @@ -17,28 +16,9 @@ def test_default_sync_middlewares(w3): (ens_name_to_address_middleware, "name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), - (abi_middleware, "abi"), (buffered_gas_estimate_middleware, "gas_estimate"), ] default_middlewares = RequestManager.default_middlewares(w3) - for x in range(len(default_middlewares)): - assert default_middlewares[x][0].__name__ == expected_middlewares[x][0].__name__ - assert default_middlewares[x][1] == expected_middlewares[x][1] - - -def test_default_async_middlewares(): - expected_middlewares = [ - (gas_price_strategy_middleware, "gas_price_strategy"), - (ens_name_to_address_middleware, "name_to_address"), - (attrdict_middleware, "attrdict"), - (validation_middleware, "validation"), - (buffered_gas_estimate_middleware, "gas_estimate"), - ] - - default_middlewares = RequestManager.async_default_middlewares() - - for x in range(len(default_middlewares)): - assert default_middlewares[x][0].__name__ == expected_middlewares[x][0].__name__ - assert default_middlewares[x][1] == expected_middlewares[x][1] + assert default_middlewares == expected_middlewares diff --git a/tests/core/manager/test_middleware_can_be_stateful.py b/tests/core/manager/test_middleware_can_be_stateful.py index b666dc25a0..e8613b49c6 100644 --- a/tests/core/manager/test_middleware_can_be_stateful.py +++ b/tests/core/manager/test_middleware_can_be_stateful.py @@ -1,20 +1,26 @@ from web3.manager import ( RequestManager, ) +from web3.middleware import ( + Web3Middleware, +) from web3.providers import ( BaseProvider, ) -def stateful_middleware(make_request, w3): +class StatefulMiddleware(Web3Middleware): state = [] - def middleware(method, params): - state.append((method, params)) - return {"result": state} + def _wrap_make_request(self, make_request): + def middleware(method, params): + self.state.append((method, params)) + return {"result": self.state} + + return middleware + - middleware.state = state - return middleware +stateful_middleware = StatefulMiddleware def test_middleware_holds_state_across_requests(): diff --git a/tests/core/middleware/test_eth_tester_middleware.py b/tests/core/middleware/test_eth_tester_middleware.py index 819376daa0..7d8d085e71 100644 --- a/tests/core/middleware/test_eth_tester_middleware.py +++ b/tests/core/middleware/test_eth_tester_middleware.py @@ -92,9 +92,10 @@ def mock_request(_method, params): mock_w3.eth.accounts = w3_accounts mock_w3.eth.coinbase = w3_coinbase - middleware = default_transaction_fields_middleware(mock_request, mock_w3) + middleware = default_transaction_fields_middleware(mock_w3) base_params = {"chainId": 5} - filled_transaction = middleware(method, [base_params]) + inner = middleware._wrap_make_request(mock_request) + filled_transaction = inner(method, [base_params]) filled_params = filled_transaction[0] @@ -177,9 +178,10 @@ async def mock_async_coinbase(): mock_w3.eth.accounts = mock_async_accounts() mock_w3.eth.coinbase = mock_async_coinbase() - middleware = default_transaction_fields_middleware(mock_request, mock_w3) + middleware = default_transaction_fields_middleware(mock_w3) base_params = {"chainId": 5} - filled_transaction = await middleware(method, [base_params]) + inner = await middleware._async_wrap_make_request(mock_request) + filled_transaction = await inner(method, [base_params]) filled_params = filled_transaction[0] assert ("from" in filled_params.keys()) == from_field_added diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index 94bc3ab5ed..3f7ad21583 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -6,6 +6,7 @@ import pytest_asyncio from web3 import ( + AsyncWeb3, Web3, ) from web3.datastructures import ( @@ -84,29 +85,20 @@ def iterator(): @pytest.fixture(scope="function") -def result_generator_middleware(iter_block_number): - return None - # return construct_result_generator_middleware( - # { - # "eth_getLogs": lambda *_: FILTER_LOG, - # "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, - # "net_version": lambda *_: 1, - # "eth_blockNumber": lambda *_: next(iter_block_number), - # } - # ) - - -@pytest.fixture(scope="function") -def w3_base(): - return Web3(provider=DummyProvider(), middlewares=[]) - - -@pytest.fixture(scope="function") -def w3(w3_base, result_generator_middleware): - w3_base.middleware_onion.add(result_generator_middleware) +def w3(request_mocker, iter_block_number): + w3_base = Web3(provider=DummyProvider(), middlewares=[]) w3_base.middleware_onion.add(attrdict_middleware) w3_base.middleware_onion.add(local_filter_middleware) - return w3_base + with request_mocker( + w3_base, + mock_results={ + "eth_getLogs": lambda *_: FILTER_LOG, + "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, + "net_version": lambda *_: 1, + "eth_blockNumber": lambda *_: next(iter_block_number), + }, + ): + yield w3_base @pytest.mark.parametrize( @@ -263,30 +255,21 @@ async def make_request(self, method, params): @pytest_asyncio.fixture(scope="function") -async def async_result_generator_middleware(iter_block_number): - return None - # return await async_construct_result_generator_middleware( - # { - # "eth_getLogs": lambda *_: FILTER_LOG, - # "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, - # "net_version": lambda *_: 1, - # "eth_blockNumber": lambda *_: next(iter_block_number), - # } - # ) - - -@pytest.fixture(scope="function") -def async_w3_base(): - return Web3( - provider=AsyncDummyProvider(), modules={"eth": (AsyncEth)}, middlewares=[] - ) - - -@pytest.fixture(scope="function") -def async_w3(async_w3_base, async_result_generator_middleware): +async def async_w3(request_mocker, iter_block_number): + async_w3_base = AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) async_w3_base.middleware_onion.add(attrdict_middleware) async_w3_base.middleware_onion.add(local_filter_middleware) - return async_w3_base + + async with request_mocker( + async_w3_base, + mock_results={ + "eth_getLogs": lambda *_: FILTER_LOG, + "eth_getBlockByNumber": lambda *_: {"hash": BLOCK_HASH}, + "net_version": lambda *_: 1, + "eth_blockNumber": lambda *_: next(iter_block_number), + }, + ): + yield async_w3_base @pytest.mark.parametrize( diff --git a/tests/core/middleware/test_formatting_middleware.py b/tests/core/middleware/test_formatting_middleware.py index 73f4f34a0e..1bd5cec913 100644 --- a/tests/core/middleware/test_formatting_middleware.py +++ b/tests/core/middleware/test_formatting_middleware.py @@ -6,12 +6,13 @@ from web3 import ( Web3, ) +from web3.middleware import ( + construct_formatting_middleware, +) from web3.providers.base import ( BaseProvider, ) -from web3.types import ( - RPCEndpoint, -) +from web3.types import RPCEndpoint class DummyProvider(BaseProvider): @@ -28,86 +29,70 @@ def test_formatting_middleware(w3, request_mocker): # No formatters by default expected = "done" with request_mocker(w3, mock_results={"test_endpoint": "done"}): - actual = w3.provider.make_request("test_endpoint", []) + actual = w3.manager.request_blocking(RPCEndpoint("test_endpoint"), []) assert actual == expected def test_formatting_middleware_no_method(w3): - # w3.middleware_onion.add(construct_formatting_middleware()) + w3.middleware_onion.add(construct_formatting_middleware()) # Formatting middleware requires an endpoint with pytest.raises(NotImplementedError): w3.manager.request_blocking("test_endpoint", []) -def test_formatting_middleware_request_formatters(w3): +def test_formatting_middleware_request_formatters(w3, request_mocker): callable_mock = Mock() - # w3.middleware_onion.add( - # construct_result_generator_middleware( - # {RPCEndpoint("test_endpoint"): lambda method, params: "done"} - # ) - # ) - # - # w3.middleware_onion.add( - # construct_formatting_middleware( - # request_formatters={RPCEndpoint("test_endpoint"): callable_mock} - # ) - # ) + w3.middleware_onion.add( + construct_formatting_middleware( + request_formatters={"test_endpoint": callable_mock} + ) + ) expected = "done" - actual = w3.manager.request_blocking("test_endpoint", ["param1"]) + with request_mocker(w3, mock_results={"test_endpoint": "done"}): + actual = w3.manager.request_blocking("test_endpoint", ["param1"]) callable_mock.assert_called_once_with(["param1"]) assert actual == expected -def test_formatting_middleware_result_formatters(w3): - # w3.middleware_onion.add( - # construct_result_generator_middleware( - # {RPCEndpoint("test_endpoint"): lambda method, params: "done"} - # ) - # ) - # w3.middleware_onion.add( - # construct_formatting_middleware( - # result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} - # ) - # ) - - expected = "STATUS:done" - actual = w3.manager.request_blocking("test_endpoint", []) +def test_formatting_middleware_result_formatters(w3, request_mocker): + w3.middleware_onion.add( + construct_formatting_middleware( + result_formatters={"test_endpoint": lambda x: f"STATUS: {x}"} + ) + ) + + expected = "STATUS: done" + with request_mocker(w3, mock_results={"test_endpoint": "done"}): + actual = w3.manager.request_blocking("test_endpoint", []) + assert actual == expected -def test_formatting_middleware_result_formatters_for_none(w3): - # w3.middleware_onion.add( - # construct_result_generator_middleware( - # {RPCEndpoint("test_endpoint"): lambda method, params: None} - # ) - # ) - # w3.middleware_onion.add( - # construct_formatting_middleware( - # result_formatters={RPCEndpoint("test_endpoint"): lambda x: hex(x)} - # ) - # ) +def test_formatting_middleware_result_formatters_for_none(w3, request_mocker): + w3.middleware_onion.add( + construct_formatting_middleware( + result_formatters={"test_endpoint": lambda x: hex(x)} + ) + ) expected = None - actual = w3.manager.request_blocking("test_endpoint", []) + with request_mocker(w3, mock_results={"test_endpoint": expected}): + actual = w3.manager.request_blocking("test_endpoint", []) assert actual == expected -def test_formatting_middleware_error_formatters(w3): - # w3.middleware_onion.add( - # construct_error_generator_middleware( - # {RPCEndpoint("test_endpoint"): lambda method, params: "error"} - # ) - # ) - # w3.middleware_onion.add( - # construct_formatting_middleware( - # result_formatters={RPCEndpoint("test_endpoint"): lambda x: f"STATUS:{x}"} - # ) - # ) +def test_formatting_middleware_error_formatters(w3, request_mocker): + w3.middleware_onion.add( + construct_formatting_middleware( + result_formatters={"test_endpoint": lambda x: f"STATUS: {x}"} + ) + ) expected = "error" - with pytest.raises(ValueError) as err: - w3.manager.request_blocking("test_endpoint", []) - assert str(err.value) == expected + with request_mocker(w3, mock_errors={"test_endpoint": {"message": "error"}}): + with pytest.raises(ValueError) as err: + w3.manager.request_blocking("test_endpoint", []) + assert str(err.value) == expected diff --git a/tests/core/middleware/test_gas_price_strategy.py b/tests/core/middleware/test_gas_price_strategy.py index c5193d41c2..54a993cb6a 100644 --- a/tests/core/middleware/test_gas_price_strategy.py +++ b/tests/core/middleware/test_gas_price_strategy.py @@ -3,6 +3,8 @@ Mock, ) +from toolz import merge + from web3.middleware import ( gas_price_strategy_middleware, ) @@ -10,84 +12,57 @@ @pytest.fixture def the_gas_price_strategy_middleware(w3): - make_request, w3 = Mock(), Mock() - initialized = gas_price_strategy_middleware(make_request, w3) - initialized.w3 = w3 - initialized.make_request = make_request + w3 = Mock() + initialized = gas_price_strategy_middleware(w3) return initialized def test_gas_price_generated(the_gas_price_strategy_middleware): - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.return_value = 5 - method = "eth_sendTransaction" - params = ( - { - "to": "0x0", - "value": 1, - }, - ) - the_gas_price_strategy_middleware(method, params) - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.assert_called_once_with( - { - "to": "0x0", - "value": 1, - } + the_gas_price_strategy_middleware._w3.eth.generate_gas_price.return_value = 5 + + make_request = Mock() + inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) + method, dict_param = "eth_sendTransaction", {"to": "0x0", "value": 1} + inner(method, (dict_param,)) + + the_gas_price_strategy_middleware._w3.eth.generate_gas_price.assert_called_once_with( + dict_param ) - the_gas_price_strategy_middleware.make_request.assert_called_once_with( - method, - ( - { - "to": "0x0", - "value": 1, - "gasPrice": "0x5", - }, - ), + make_request.assert_called_once_with( + method, (merge(dict_param, {"gasPrice": "0x5"}),) ) def test_gas_price_not_overridden(the_gas_price_strategy_middleware): - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.return_value = 5 - method = "eth_sendTransaction" - params = ( - { - "to": "0x0", - "value": 1, - "gasPrice": 10, - }, - ) - the_gas_price_strategy_middleware(method, params) - the_gas_price_strategy_middleware.make_request.assert_called_once_with( - method, - ( - { - "to": "0x0", - "value": 1, - "gasPrice": 10, - }, - ), - ) + the_gas_price_strategy_middleware._w3.eth.generate_gas_price.return_value = 5 + + make_request = Mock() + inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) + method, params = "eth_sendTransaction", ({"to": "0x0", "value": 1, "gasPrice": 10},) + inner(method, params) + + make_request.assert_called_once_with(method, params) def test_gas_price_not_set_without_gas_price_strategy( the_gas_price_strategy_middleware, ): - the_gas_price_strategy_middleware.w3.eth.generate_gas_price.return_value = None - method = "eth_sendTransaction" - params = ( - { - "to": "0x0", - "value": 1, - }, - ) - the_gas_price_strategy_middleware(method, params) - the_gas_price_strategy_middleware.make_request.assert_called_once_with( - method, params - ) + the_gas_price_strategy_middleware._w3.eth.generate_gas_price.return_value = None + + make_request = Mock() + inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) + method, params = "eth_sendTransaction", ({"to": "0x0", "value": 1},) + inner(method, params) + + make_request.assert_called_once_with(method, params) def test_not_generate_gas_price_when_not_send_transaction_rpc( the_gas_price_strategy_middleware, ): - the_gas_price_strategy_middleware.w3.getGasPriceStrategy = Mock() - the_gas_price_strategy_middleware("eth_getBalance", []) - the_gas_price_strategy_middleware.w3.getGasPriceStrategy.assert_not_called() + the_gas_price_strategy_middleware._w3.get_gas_price_strategy = Mock() + + inner = the_gas_price_strategy_middleware._wrap_make_request(Mock()) + inner("eth_getBalance", []) + + the_gas_price_strategy_middleware._w3.get_gas_price_strategy.assert_not_called() diff --git a/tests/core/middleware/test_stalecheck.py b/tests/core/middleware/test_stalecheck.py index 36c4ba6e3d..a722cec56a 100644 --- a/tests/core/middleware/test_stalecheck.py +++ b/tests/core/middleware/test_stalecheck.py @@ -31,10 +31,9 @@ def allowable_delay(): @pytest.fixture def request_middleware(allowable_delay): - middleware = make_stalecheck_middleware(allowable_delay) web3 = Mock() - middleware._web3 = web3 - middleware._web3.provider.make_request = Mock() + middleware = make_stalecheck_middleware(allowable_delay, web3) + middleware._w3.provider.make_request = Mock() return middleware @@ -67,16 +66,23 @@ def test_is_fresh(now): def test_stalecheck_pass(request_middleware): with patch("web3.middleware.stalecheck._is_fresh", return_value=True): + make_request = Mock() + inner = request_middleware._wrap_make_request(make_request) + method, params = object(), object() - request_middleware(method, params) - request_middleware.make_request.assert_called_once_with(method, params) + inner(method, params) + + make_request.assert_called_once_with(method, params) def test_stalecheck_fail(request_middleware, now): with patch("web3.middleware.stalecheck._is_fresh", return_value=False): - request_middleware.web3.eth.get_block.return_value = stub_block(now) + request_middleware._w3.eth.get_block.return_value = stub_block(now) + + response = object() + inner = request_middleware._wrap_make_request(lambda *_: response) with pytest.raises(StaleBlockchain): - request_middleware("", []) + inner("", []) @pytest.mark.parametrize( @@ -89,8 +95,9 @@ def test_stalecheck_ignores_get_by_block_methods(request_middleware, rpc_method) # This is especially critical for get_block('latest') # which would cause infinite recursion with patch("web3.middleware.stalecheck._is_fresh", side_effect=[False, True]): - request_middleware(rpc_method, []) - assert not request_middleware.web3.eth.get_block.called + inner = request_middleware._wrap_make_request(lambda *_: None) + inner(rpc_method, []) + assert not request_middleware._w3.eth.get_block.called def test_stalecheck_calls_is_fresh_with_empty_cache( @@ -100,8 +107,9 @@ def test_stalecheck_calls_is_fresh_with_empty_cache( "web3.middleware.stalecheck._is_fresh", side_effect=[False, True] ) as fresh_spy: block = object() - request_middleware.web3.eth.get_block.return_value = block - request_middleware("", []) + request_middleware._w3.eth.get_block.return_value = block + inner = request_middleware._wrap_make_request(lambda *_: None) + inner("", []) cache_call, live_call = fresh_spy.call_args_list assert cache_call[0] == (None, allowable_delay) assert live_call[0] == (block, allowable_delay) @@ -112,17 +120,18 @@ def test_stalecheck_adds_block_to_cache(request_middleware, allowable_delay): "web3.middleware.stalecheck._is_fresh", side_effect=[False, True, True] ) as fresh_spy: block = object() - request_middleware.web3.eth.get_block.return_value = block + request_middleware._w3.eth.get_block.return_value = block # cache miss - request_middleware("", []) + inner = request_middleware._wrap_make_request(lambda *_: None) + inner("", []) cache_call, live_call = fresh_spy.call_args_list assert fresh_spy.call_count == 2 assert cache_call == ((None, allowable_delay),) assert live_call == ((block, allowable_delay),) # cache hit - request_middleware("", []) + inner("", []) assert fresh_spy.call_count == 3 assert fresh_spy.call_args == ((block, allowable_delay),) @@ -135,61 +144,74 @@ def test_stalecheck_adds_block_to_cache(request_middleware, allowable_delay): ) +async def _coro(_method, _params): + return None + + @pytest_asyncio.fixture -async def request_async_middleware(allowable_delay): +async def async_request_middleware(allowable_delay): from unittest.mock import ( AsyncMock, ) - middleware = make_stalecheck_middleware(allowable_delay) async_web3 = AsyncMock() - # for easier mocking, later: - middleware._web3 = async_web3 + middleware = make_stalecheck_middleware(allowable_delay, async_web3) + middleware._w3.provider.make_request = Mock() return middleware @pytest.mark.asyncio @min_version -async def test_async_stalecheck_pass(request_async_middleware): +async def test_async_stalecheck_pass(async_request_middleware): + from unittest.mock import AsyncMock + with patch("web3.middleware.stalecheck._is_fresh", return_value=True): + make_request = AsyncMock() + inner = await async_request_middleware._async_wrap_make_request(make_request) + method, params = object(), object() - await request_async_middleware(method, params) - request_async_middleware.make_request.assert_called_once_with(method, params) + await inner(method, params) + + make_request.assert_called_once_with(method, params) @pytest.mark.asyncio @min_version -async def test_async_stalecheck_fail(request_async_middleware, now): +async def test_async_stalecheck_fail(async_request_middleware, now): with patch("web3.middleware.stalecheck._is_fresh", return_value=False): - request_async_middleware.web3.eth.get_block.return_value = stub_block(now) + async_request_middleware._w3.eth.get_block.return_value = stub_block(now) + with pytest.raises(StaleBlockchain): - await request_async_middleware("", []) + inner = await async_request_middleware._async_wrap_make_request(_coro) + await inner("", []) @pytest.mark.asyncio @pytest.mark.parametrize("rpc_method", ["eth_getBlockByNumber"]) @min_version async def test_async_stalecheck_ignores_get_by_block_methods( - request_async_middleware, rpc_method + async_request_middleware, rpc_method ): # This is especially critical for get_block("latest") which would cause # infinite recursion with patch("web3.middleware.stalecheck._is_fresh", side_effect=[False, True]): - await request_async_middleware(rpc_method, []) - assert not request_async_middleware.web3.eth.get_block.called + inner = await async_request_middleware._async_wrap_make_request(_coro) + await inner(rpc_method, []) + assert not async_request_middleware._w3.eth.get_block.called @pytest.mark.asyncio @min_version async def test_async_stalecheck_calls_is_fresh_with_empty_cache( - request_async_middleware, allowable_delay + async_request_middleware, allowable_delay ): with patch( "web3.middleware.stalecheck._is_fresh", side_effect=[False, True] ) as fresh_spy: block = object() - request_async_middleware.web3.eth.get_block.return_value = block - await request_async_middleware("", []) + async_request_middleware._w3.eth.get_block.return_value = block + inner = await async_request_middleware._async_wrap_make_request(_coro) + await inner("", []) cache_call, live_call = fresh_spy.call_args_list assert cache_call[0] == (None, allowable_delay) assert live_call[0] == (block, allowable_delay) @@ -198,22 +220,24 @@ async def test_async_stalecheck_calls_is_fresh_with_empty_cache( @pytest.mark.asyncio @min_version async def test_async_stalecheck_adds_block_to_cache( - request_async_middleware, allowable_delay + async_request_middleware, allowable_delay ): with patch( "web3.middleware.stalecheck._is_fresh", side_effect=[False, True, True] ) as fresh_spy: block = object() - request_async_middleware.web3.eth.get_block.return_value = block + async_request_middleware._w3.eth.get_block.return_value = block + + inner = await async_request_middleware._async_wrap_make_request(_coro) # cache miss - await request_async_middleware("", []) + await inner("", []) cache_call, live_call = fresh_spy.call_args_list assert fresh_spy.call_count == 2 assert cache_call == ((None, allowable_delay),) assert live_call == ((block, allowable_delay),) # cache hit - await request_async_middleware("", []) + await inner("", []) assert fresh_spy.call_count == 3 assert fresh_spy.call_args == ((block, allowable_delay),) diff --git a/tests/core/middleware/test_transaction_signing.py b/tests/core/middleware/test_transaction_signing.py index 506d533e0b..b0297ebc41 100644 --- a/tests/core/middleware/test_transaction_signing.py +++ b/tests/core/middleware/test_transaction_signing.py @@ -416,9 +416,18 @@ async def coro_request(self, method, params): raise NotImplementedError(f"Cannot make request for {method}:{params}") -@pytest.fixture -def async_w3_dummy(): - return AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) +@pytest_asyncio.fixture +async def async_w3_dummy(request_mocker): + w3_base = AsyncWeb3(provider=AsyncDummyProvider(), middlewares=[]) + async with request_mocker( + w3_base, + mock_results={ + "eth_sendRawTransaction": lambda *args: args, + "net_version": 1, + "eth_chainId": "0x02", + }, + ): + yield w3_base @pytest.fixture diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 919defd0b6..245c8f35d9 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -80,19 +80,15 @@ def test_web3_with_http_provider_has_default_middlewares_and_modules() -> None: # the following length check should fail and will need to be added to once more # middlewares are added to the defaults - assert len(w3.middleware_onion.middlewares) == 6 + assert len(w3.middleware_onion.middlewares) == 5 assert ( w3.middleware_onion.get("gas_price_strategy") == gas_price_strategy_middleware ) - assert ( - w3.middleware_onion.get("name_to_address").__name__ - == ens_name_to_address_middleware(w3).__name__ - ) + assert w3.middleware_onion.get("name_to_address") == ens_name_to_address_middleware assert w3.middleware_onion.get("attrdict") == attrdict_middleware assert w3.middleware_onion.get("validation") == validation_middleware assert w3.middleware_onion.get("gas_estimate") == buffered_gas_estimate_middleware - assert w3.middleware_onion.get("abi") == abi_middleware def test_user_provided_session(): diff --git a/tests/ens/test_ens.py b/tests/ens/test_ens.py index af9cf99833..d5dea4397d 100644 --- a/tests/ens/test_ens.py +++ b/tests/ens/test_ens.py @@ -15,8 +15,8 @@ AsyncWeb3, ) from web3.middleware import ( - async_validation_middleware, gas_price_strategy_middleware, + validation_middleware, ) from web3.providers.eth_tester import ( AsyncEthereumTesterProvider, @@ -157,7 +157,7 @@ def local_async_w3(): def test_async_from_web3_inherits_web3_middlewares(local_async_w3): - test_middleware = async_validation_middleware + test_middleware = validation_middleware local_async_w3.middleware_onion.add(test_middleware, "test_middleware") ns = AsyncENS.from_web3(local_async_w3) diff --git a/tests/ens/test_utils.py b/tests/ens/test_utils.py index 0dd6948b09..66d9d37c26 100644 --- a/tests/ens/test_utils.py +++ b/tests/ens/test_utils.py @@ -39,8 +39,8 @@ def test_init_web3_adds_expected_middlewares(): w3 = init_web3() - middlewares = map(str, w3.manager.middleware_onion) - assert "stalecheck_middleware" in next(middlewares) + middlewares = w3.middleware_onion.middlewares + assert middlewares[0][1] == "stalecheck" @pytest.mark.parametrize( @@ -207,8 +207,8 @@ def test_label_to_hash_normalizes_name_using_ensip15(): @pytest.mark.asyncio async def test_init_async_web3_adds_expected_async_middlewares(): async_w3 = init_async_web3() - middlewares = map(str, async_w3.manager.middleware_onion) - assert "stalecheck_middleware" in next(middlewares) + middlewares = async_w3.middleware_onion.middlewares + assert middlewares[0][1] == "stalecheck" @pytest.mark.asyncio diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py index ac0b54664b..0166e38eb5 100644 --- a/web3/_utils/module_testing/utils.py +++ b/web3/_utils/module_testing/utils.py @@ -1,4 +1,5 @@ import copy +from asyncio import iscoroutine, iscoroutinefunction from typing import ( Any, Callable, @@ -121,7 +122,14 @@ async def _async_mock_request_handler(self, method, params): response_dict = {"jsonrpc": "2.0", "id": request_id} if method in self.mock_results: - return merge(response_dict, {"result": self.mock_results[method]}) + mock_return = self.mock_results[method] + if isinstance(mock_return, Callable): + # handle callable to make things easier since we're mocking + mock_return = mock_return(method, params) + elif iscoroutinefunction(mock_return): + # this is the "correct" way to mock the async make_request + mock_return = await mock_return(method, params) + return merge(response_dict, {"result": mock_return}) elif method in self.mock_errors: error = self.mock_errors[method] if not isinstance(error, dict): diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index c8d9373d20..16a90570c0 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -36,6 +36,9 @@ from .filter import ( local_filter_middleware, ) +from .formatting import ( + construct_formatting_middleware, +) from .gas_price_strategy import ( GasPriceStrategyMiddleware, ) diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 33987e1f78..4870272d9b 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -213,3 +213,6 @@ async def async_response_processor( self.error_formatters, response, ) + + +construct_formatting_middleware = FormattingMiddlewareBuilder.build From 731e3499f692086e57c08b87e8af1d23ecdaf7fe Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 8 Dec 2023 15:21:33 -0700 Subject: [PATCH 12/24] Remove the necessity for the ``abi_middleware`` - closes #2995 - ``ens_name_to_address_middleware`` would return the ``bytes`` address, rather than checksummed address. Since this happens after the ``request_formatters`` have been applied, this would override the results from the ``ABI_REQUEST_FORMATTERS`` which take ``address`` types and format them to a checksummed hex address. --- web3/_utils/ens.py | 3 ++- web3/middleware/__init__.py | 3 --- web3/middleware/abi.py | 10 ---------- 3 files changed, 2 insertions(+), 14 deletions(-) delete mode 100644 web3/middleware/abi.py diff --git a/web3/_utils/ens.py b/web3/_utils/ens.py index a1bd8a1437..5db42531de 100644 --- a/web3/_utils/ens.py +++ b/web3/_utils/ens.py @@ -17,6 +17,7 @@ is_0x_prefixed, is_hex, is_hex_address, + to_checksum_address, ) from ens import ( @@ -51,7 +52,7 @@ def is_ens_name(value: Any) -> bool: def validate_name_has_address(ens: ENS, name: str) -> ChecksumAddress: addr = ens.address(name) if addr: - return addr + return to_checksum_address(addr) else: raise NameNotFound(f"Could not find address for name {name!r}") diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 16a90570c0..bca997ace3 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -12,9 +12,6 @@ RPCResponse, ) -from .abi import ( - abi_middleware, -) from .attrdict import ( attrdict_middleware, ) diff --git a/web3/middleware/abi.py b/web3/middleware/abi.py deleted file mode 100644 index 64b786ae75..0000000000 --- a/web3/middleware/abi.py +++ /dev/null @@ -1,10 +0,0 @@ -from web3._utils.method_formatters import ( - ABI_REQUEST_FORMATTERS, -) -from web3.middleware.formatting import ( - FormattingMiddlewareBuilder, -) - -abi_middleware = FormattingMiddlewareBuilder.build( - request_formatters=ABI_REQUEST_FORMATTERS -) From caaa6185991105adb80226bb89cb937f22713822 Mon Sep 17 00:00:00 2001 From: fselmo Date: Wed, 13 Dec 2023 11:08:46 -0700 Subject: [PATCH 13/24] Reinstate the http retry request tests as configuration --- .../middleware/test_http_request_retry.py | 195 ------------------ .../core/providers/test_http_request_retry.py | 185 +++++++++++++++++ web3/providers/rpc/rpc.py | 3 + web3/providers/rpc/utils.py | 33 ++- 4 files changed, 212 insertions(+), 204 deletions(-) delete mode 100644 tests/core/middleware/test_http_request_retry.py create mode 100644 tests/core/providers/test_http_request_retry.py diff --git a/tests/core/middleware/test_http_request_retry.py b/tests/core/middleware/test_http_request_retry.py deleted file mode 100644 index f1cdb4f6d7..0000000000 --- a/tests/core/middleware/test_http_request_retry.py +++ /dev/null @@ -1,195 +0,0 @@ -# TODO: Redo these tests but as the provider configuration for retrying requests -# import pytest -# from unittest.mock import ( -# Mock, -# patch, -# ) -# -# import aiohttp -# import pytest_asyncio -# from requests.exceptions import ( -# ConnectionError, -# HTTPError, -# Timeout, -# TooManyRedirects, -# ) -# -# import web3 -# from web3 import ( -# AsyncHTTPProvider, -# AsyncWeb3, -# ) -# from web3.providers import ( -# HTTPProvider, -# IPCProvider, -# ) -# from web3.providers.rpc.utils import check_if_retry_on_failure -# -# -# def test_check_if_retry_on_failure_false(): -# methods = [ -# "eth_sendTransaction", -# "personal_signAndSendTransaction", -# "personal_sendTransaction", -# ] -# -# for method in methods: -# assert not check_if_retry_on_failure(method) -# -# -# def test_check_if_retry_on_failure_true(): -# method = "eth_getBalance" -# assert check_if_retry_on_failure(method) -# -# -# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -# def test_check_send_transaction_called_once( -# make_post_request_mock, exception_retry_request_setup -# ): -# method = "eth_sendTransaction" -# params = [ -# { -# "to": "0x0", -# "value": 1, -# } -# ] -# -# with pytest.raises(ConnectionError): -# exception_retry_request_setup(method, params) -# assert make_post_request_mock.call_count == 1 -# -# -# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -# def test_valid_method_retried(make_post_request_mock, exception_retry_request_setup): -# method = "eth_getBalance" -# params = [] -# -# with pytest.raises(ConnectionError): -# exception_retry_request_setup(method, params) -# assert make_post_request_mock.call_count == 5 -# -# -# def test_is_strictly_default_http_middleware(): -# w3 = HTTPProvider() -# assert "http_retry_request" in w3.middlewares -# -# w3 = IPCProvider() -# assert "http_retry_request" not in w3.middlewares -# -# -# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -# def test_check_with_all_middlewares(make_post_request_mock): -# provider = HTTPProvider() -# w3 = web3.Web3(provider) -# with pytest.raises(ConnectionError): -# w3.eth.block_number -# assert make_post_request_mock.call_count == 5 -# -# -# @patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) -# def test_exception_retry_middleware_with_allow_list_kwarg( -# make_post_request_mock, exception_retry_request_setup -# ): -# w3 = Mock() -# provider = HTTPProvider() -# errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) -# setup = exception_retry_middleware( -# provider.make_request, w3, errors, 5, allow_list=["test_userProvidedMethod"] -# ) -# setup.w3 = w3 -# with pytest.raises(ConnectionError): -# setup("test_userProvidedMethod", []) -# assert make_post_request_mock.call_count == 5 -# -# make_post_request_mock.reset_mock() -# with pytest.raises(ConnectionError): -# setup("eth_getBalance", []) -# assert make_post_request_mock.call_count == 1 -# -# -# # -- async -- # -# -# -# ASYNC_TEST_RETRY_COUNT = 3 -# -# -# @pytest_asyncio.fixture -# async def async_exception_retry_request_setup(): -# w3 = Mock() -# provider = AsyncHTTPProvider() -# setup = await async_exception_retry_middleware( -# provider.make_request, -# w3, -# (TimeoutError, aiohttp.ClientError), -# ASYNC_TEST_RETRY_COUNT, -# 0.1, -# ) -# setup.w3 = w3 -# return setup -# -# -# @pytest.mark.asyncio -# @pytest.mark.parametrize( -# "error", -# ( -# TimeoutError, -# aiohttp.ClientError, # general base class for all aiohttp errors -# # ClientError subclasses -# aiohttp.ClientConnectionError, -# aiohttp.ServerTimeoutError, -# aiohttp.ClientOSError, -# ), -# ) -# async def test_async_check_retry_middleware(error, async_exception_retry_request_setup): -# with patch( -# "web3.providers.async_rpc.async_make_post_request" -# ) as async_make_post_request_mock: -# async_make_post_request_mock.side_effect = error -# -# with pytest.raises(error): -# await async_exception_retry_request_setup("eth_getBalance", []) -# assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT -# -# -# @pytest.mark.asyncio -# async def test_async_check_without_retry_middleware(): -# with patch( -# "web3.providers.async_rpc.async_make_post_request" -# ) as async_make_post_request_mock: -# async_make_post_request_mock.side_effect = TimeoutError -# provider = AsyncHTTPProvider() -# w3 = AsyncWeb3(provider) -# w3.provider._middlewares = () -# -# with pytest.raises(TimeoutError): -# await w3.eth.block_number -# assert async_make_post_request_mock.call_count == 1 -# -# -# @pytest.mark.asyncio -# async def test_async_exception_retry_middleware_with_allow_list_kwarg(): -# w3 = Mock() -# provider = AsyncHTTPProvider() -# setup = await async_exception_retry_middleware( -# provider.make_request, -# w3, -# (TimeoutError, aiohttp.ClientError), -# retries=ASYNC_TEST_RETRY_COUNT, -# backoff_factor=0.1, -# allow_list=["test_userProvidedMethod"], -# ) -# setup.w3 = w3 -# -# with patch( -# "web3.providers.async_rpc.async_make_post_request" -# ) as async_make_post_request_mock: -# async_make_post_request_mock.side_effect = TimeoutError -# -# with pytest.raises(TimeoutError): -# await setup("test_userProvidedMethod", []) -# assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT -# -# async_make_post_request_mock.reset_mock() -# with pytest.raises(TimeoutError): -# await setup("eth_getBalance", []) -# assert async_make_post_request_mock.call_count == 1 diff --git a/tests/core/providers/test_http_request_retry.py b/tests/core/providers/test_http_request_retry.py new file mode 100644 index 0000000000..ad27ccafba --- /dev/null +++ b/tests/core/providers/test_http_request_retry.py @@ -0,0 +1,185 @@ +import pytest +from unittest.mock import ( + Mock, + patch, +) + +import aiohttp +import pytest_asyncio +from requests.exceptions import ( + ConnectionError, + HTTPError, + Timeout, + TooManyRedirects, +) + +import web3 +from web3 import ( + AsyncHTTPProvider, + AsyncWeb3, + Web3, +) +from web3.providers import ( + HTTPProvider, + IPCProvider, +) +from web3.providers.rpc.utils import ( + ExceptionRetryConfiguration, + check_if_retry_on_failure, +) +from web3.types import RPCEndpoint + + +TEST_RETRY_COUNT = 3 + + +@pytest.fixture +def w3(): + errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) + config = ExceptionRetryConfiguration(errors=errors, retries=TEST_RETRY_COUNT) + return Web3(HTTPProvider(exception_retry_configuration=config)) + + +def test_default_request_retry_configuration_for_http_provider(): + w3 = Web3(HTTPProvider()) + assert w3.provider.exception_retry_configuration == ExceptionRetryConfiguration() + + +def test_check_if_retry_on_failure_false(): + methods = [ + "eth_sendTransaction", + "personal_signAndSendTransaction", + "personal_sendTransaction", + ] + + for method in methods: + assert not check_if_retry_on_failure(method) + + +def test_check_if_retry_on_failure_true(): + method = "eth_getBalance" + assert check_if_retry_on_failure(method) + + +@patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) +def test_check_send_transaction_called_once(make_post_request_mock, w3): + with pytest.raises(ConnectionError): + w3.provider.make_request( + RPCEndpoint("eth_sendTransaction"), [{"to": f"0x{'00' * 20}", "value": 1}] + ) + assert make_post_request_mock.call_count == 1 + + +@patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) +def test_valid_method_retried(make_post_request_mock, w3): + with pytest.raises(ConnectionError): + w3.provider.make_request(RPCEndpoint("eth_getBalance"), [f"0x{'00' * 20}"]) + assert make_post_request_mock.call_count == TEST_RETRY_COUNT + + +def test_exception_retry_config_is_strictly_on_http_provider(): + w3 = Web3(HTTPProvider()) + assert hasattr(w3.provider, "exception_retry_configuration") + + w3 = Web3(IPCProvider()) + assert not hasattr(w3.provider, "exception_retry_configuration") + + +@patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) +def test_exception_retry_middleware_with_allow_list_kwarg(make_post_request_mock): + config = ExceptionRetryConfiguration( + errors=(ConnectionError, HTTPError, Timeout, TooManyRedirects), + retries=TEST_RETRY_COUNT, + method_allowlist=["test_userProvidedMethod"], + ) + w3 = Web3(HTTPProvider(exception_retry_configuration=config)) + + with pytest.raises(ConnectionError): + w3.provider.make_request(RPCEndpoint("test_userProvidedMethod"), []) + assert make_post_request_mock.call_count == TEST_RETRY_COUNT + + make_post_request_mock.reset_mock() + with pytest.raises(ConnectionError): + w3.provider.make_request(RPCEndpoint("eth_getBalance"), []) + assert make_post_request_mock.call_count == 1 + + +# -- async -- # + + +@pytest.fixture +def async_w3(): + errors = (aiohttp.ClientError, TimeoutError) + config = ExceptionRetryConfiguration(errors=errors, retries=TEST_RETRY_COUNT) + return AsyncWeb3(AsyncHTTPProvider(exception_retry_configuration=config)) + + +@pytest.mark.asyncio +async def test_async_default_request_retry_configuration_for_http_provider(): + async_w3 = AsyncWeb3(AsyncHTTPProvider()) + assert ( + async_w3.provider.exception_retry_configuration == ExceptionRetryConfiguration() + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "error", + ( + TimeoutError, + aiohttp.ClientError, # general base class for all aiohttp errors + # ClientError subclasses + aiohttp.ClientConnectionError, + aiohttp.ServerTimeoutError, + aiohttp.ClientOSError, + ), +) +async def test_async_check_retry_middleware(async_w3, error): + with patch( + "web3.providers.rpc.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = error + + with pytest.raises(error): + await async_w3.provider.make_request(RPCEndpoint("eth_getBalance"), []) + assert async_make_post_request_mock.call_count == TEST_RETRY_COUNT + + +@pytest.mark.asyncio +async def test_async_check_without_retry_config(): + w3 = AsyncWeb3(AsyncHTTPProvider(exception_retry_configuration=None)) + + with patch( + "web3.providers.rpc.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = TimeoutError + + with pytest.raises(TimeoutError): + await w3.eth.block_number + assert async_make_post_request_mock.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_exception_retry_middleware_with_allow_list_kwarg(): + config = ExceptionRetryConfiguration( + errors=(aiohttp.ClientError, TimeoutError), + retries=TEST_RETRY_COUNT, + method_allowlist=["test_userProvidedMethod"], + ) + async_w3 = AsyncWeb3(AsyncHTTPProvider(exception_retry_configuration=config)) + + with patch( + "web3.providers.rpc.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = TimeoutError + + with pytest.raises(TimeoutError): + await async_w3.provider.make_request( + RPCEndpoint("test_userProvidedMethod"), [] + ) + assert async_make_post_request_mock.call_count == TEST_RETRY_COUNT + + async_make_post_request_mock.reset_mock() + with pytest.raises(TimeoutError): + await async_w3.provider.make_request(RPCEndpoint("eth_getBalance"), []) + assert async_make_post_request_mock.call_count == 1 diff --git a/web3/providers/rpc/rpc.py b/web3/providers/rpc/rpc.py index 9d7aa05997..92238c72f0 100644 --- a/web3/providers/rpc/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -48,11 +48,14 @@ class HTTPProvider(JSONBaseProvider): logger = logging.getLogger("web3.providers.HTTPProvider") endpoint_uri = None + _request_args = None _request_kwargs = None # type ignored b/c conflict with _middlewares attr on BaseProvider _middlewares: Tuple[Middleware, ...] = NamedElementOnion([]) # type: ignore # noqa: E501 + exception_retry_configuration: Optional[ExceptionRetryConfiguration] = None + def __init__( self, endpoint_uri: Optional[Union[URI, str]] = None, diff --git a/web3/providers/rpc/utils.py b/web3/providers/rpc/utils.py index 9d87ad5bc4..4236b01523 100644 --- a/web3/providers/rpc/utils.py +++ b/web3/providers/rpc/utils.py @@ -1,4 +1,5 @@ from typing import ( + Optional, Sequence, Type, ) @@ -79,18 +80,32 @@ def check_if_retry_on_failure( if allowlist is None: allowlist = REQUEST_RETRY_ALLOWLIST - if method in allowlist or method.split("_")[0]: + if method in allowlist or method.split("_")[0] in allowlist: return True else: return False class ExceptionRetryConfiguration(BaseModel): - errors: Sequence[Type[BaseException]] = ( - ConnectionError, - requests.HTTPError, - requests.Timeout, - ) - retries: int = 5 - backoff_factor: float = 0.5 - method_allowlist: Sequence[str] = REQUEST_RETRY_ALLOWLIST + errors: Sequence[Type[BaseException]] + retries: int + backoff_factor: float + method_allowlist: Sequence[str] + + def __init__( + self, + errors: Sequence[Type[BaseException]] = ( + ConnectionError, + requests.HTTPError, + requests.Timeout, + ), + retries: int = 5, + backoff_factor: float = 0.5, + method_allowlist: Sequence[str] = None, + ): + super().__init__( + errors=errors, + retries=retries, + backoff_factor=backoff_factor, + method_allowlist=method_allowlist or REQUEST_RETRY_ALLOWLIST, + ) From 89c6a9c32b1fcffc1c6bbc29fc17f8a7514df2cf Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 14 Dec 2023 16:44:22 -0700 Subject: [PATCH 14/24] Fix type hinting for refactor --- ens/async_ens.py | 6 +- ens/ens.py | 5 +- ens/utils.py | 7 +- tests/core/contracts/test_contract_example.py | 3 - tests/core/eth-module/test_transactions.py | 1 - tests/core/filtering/utils.py | 3 - tests/core/manager/conftest.py | 4 +- .../core/manager/test_default_middlewares.py | 2 +- .../test_middleware_can_be_stateful.py | 2 +- .../core/middleware/test_filter_middleware.py | 3 - .../middleware/test_formatting_middleware.py | 4 +- .../middleware/test_gas_price_strategy.py | 13 +-- tests/core/middleware/test_stalecheck.py | 4 +- .../middleware/test_transaction_signing.py | 4 +- .../providers/test_async_http_provider.py | 2 +- tests/core/providers/test_http_provider.py | 3 +- .../core/providers/test_http_request_retry.py | 8 +- tests/core/providers/test_ipc_provider.py | 4 +- tests/ens/conftest.py | 3 +- web3/_utils/caching.py | 56 ++++++++----- web3/_utils/module_testing/eth_module.py | 2 +- web3/_utils/module_testing/utils.py | 80 +++++++++++------- web3/datastructures.py | 4 +- web3/main.py | 7 +- web3/manager.py | 33 +++----- web3/middleware/__init__.py | 33 ++++---- web3/middleware/attrdict.py | 10 +-- web3/middleware/base.py | 37 ++++++--- web3/middleware/buffered_gas_estimate.py | 5 +- web3/middleware/filter.py | 83 +++++++++++-------- web3/middleware/formatting.py | 17 ++-- web3/middleware/gas_price_strategy.py | 9 +- web3/middleware/names.py | 7 +- web3/middleware/signing.py | 27 +++--- web3/middleware/stalecheck.py | 16 +++- web3/providers/async_base.py | 27 ++---- web3/providers/base.py | 35 ++++---- web3/providers/eth_tester/main.py | 62 ++++++++++++-- web3/providers/rpc/async_rpc.py | 7 +- web3/providers/rpc/rpc.py | 17 ++-- web3/providers/rpc/utils.py | 1 - web3/providers/websocket/websocket_v2.py | 2 - web3/tools/benchmark/main.py | 7 +- web3/types.py | 14 +--- 44 files changed, 389 insertions(+), 290 deletions(-) diff --git a/ens/async_ens.py b/ens/async_ens.py index 48a2646a12..f476df4dad 100644 --- a/ens/async_ens.py +++ b/ens/async_ens.py @@ -71,12 +71,14 @@ AsyncContractFunction, ) from web3.main import AsyncWeb3 # noqa: F401 + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) from web3.providers import ( # noqa: F401 AsyncBaseProvider, BaseProvider, ) from web3.types import ( # noqa: F401 - AsyncMiddleware, TxParams, ) @@ -98,7 +100,7 @@ def __init__( self, provider: "AsyncBaseProvider" = cast("AsyncBaseProvider", default), addr: ChecksumAddress = None, - middlewares: Optional[Sequence[Tuple["AsyncMiddleware", str]]] = None, + middlewares: Optional[Sequence[Tuple["Middleware", str]]] = None, ) -> None: """ :param provider: a single provider used to connect to Ethereum diff --git a/ens/ens.py b/ens/ens.py index 97b24a2896..d8e334678a 100644 --- a/ens/ens.py +++ b/ens/ens.py @@ -71,11 +71,14 @@ Contract, ContractFunction, ) + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) from web3.providers import ( # noqa: F401 BaseProvider, ) from web3.types import ( # noqa: F401 - Middleware, + MakeRequestFn, TxParams, ) diff --git a/ens/utils.py b/ens/utils.py index 8e99c81018..9cff648e11 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -59,14 +59,15 @@ AsyncWeb3, Web3 as _Web3, ) + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) from web3.providers import ( # noqa: F401 AsyncBaseProvider, BaseProvider, ) from web3.types import ( # noqa: F401 ABIFunction, - AsyncMiddleware, - Middleware, RPCEndpoint, ) @@ -299,7 +300,7 @@ def get_abi_output_types(abi: "ABIFunction") -> List[str]: def init_async_web3( provider: "AsyncBaseProvider" = cast("AsyncBaseProvider", default), - middlewares: Optional[Sequence[Tuple["AsyncMiddleware", str]]] = (), + middlewares: Optional[Sequence[Tuple["Middleware", str]]] = (), ) -> "AsyncWeb3": from web3 import ( AsyncWeb3 as AsyncWeb3Main, diff --git a/tests/core/contracts/test_contract_example.py b/tests/core/contracts/test_contract_example.py index 28a52cba80..738b365e4f 100644 --- a/tests/core/contracts/test_contract_example.py +++ b/tests/core/contracts/test_contract_example.py @@ -9,9 +9,6 @@ EthereumTesterProvider, Web3, ) -from web3.eth import ( - AsyncEth, -) from web3.providers.eth_tester.main import ( AsyncEthereumTesterProvider, ) diff --git a/tests/core/eth-module/test_transactions.py b/tests/core/eth-module/test_transactions.py index dbb668c40a..7484fd29d2 100644 --- a/tests/core/eth-module/test_transactions.py +++ b/tests/core/eth-module/test_transactions.py @@ -1,6 +1,5 @@ import collections import itertools - import pytest from eth_utils import ( diff --git a/tests/core/filtering/utils.py b/tests/core/filtering/utils.py index b3e27036ef..d7b0f8cbe0 100644 --- a/tests/core/filtering/utils.py +++ b/tests/core/filtering/utils.py @@ -2,9 +2,6 @@ AsyncWeb3, Web3, ) -from web3.eth import ( - AsyncEth, -) from web3.middleware import ( local_filter_middleware, ) diff --git a/tests/core/manager/conftest.py b/tests/core/manager/conftest.py index d522e3d438..cc2c86c063 100644 --- a/tests/core/manager/conftest.py +++ b/tests/core/manager/conftest.py @@ -1,7 +1,9 @@ import itertools import pytest -from web3.middleware import Web3Middleware +from web3.middleware.base import ( + Web3Middleware, +) @pytest.fixture diff --git a/tests/core/manager/test_default_middlewares.py b/tests/core/manager/test_default_middlewares.py index b55815058c..179fd50ecb 100644 --- a/tests/core/manager/test_default_middlewares.py +++ b/tests/core/manager/test_default_middlewares.py @@ -19,6 +19,6 @@ def test_default_sync_middlewares(w3): (buffered_gas_estimate_middleware, "gas_estimate"), ] - default_middlewares = RequestManager.default_middlewares(w3) + default_middlewares = RequestManager.get_default_middlewares() assert default_middlewares == expected_middlewares diff --git a/tests/core/manager/test_middleware_can_be_stateful.py b/tests/core/manager/test_middleware_can_be_stateful.py index e8613b49c6..ae8ef8314c 100644 --- a/tests/core/manager/test_middleware_can_be_stateful.py +++ b/tests/core/manager/test_middleware_can_be_stateful.py @@ -1,7 +1,7 @@ from web3.manager import ( RequestManager, ) -from web3.middleware import ( +from web3.middleware.base import ( Web3Middleware, ) from web3.providers import ( diff --git a/tests/core/middleware/test_filter_middleware.py b/tests/core/middleware/test_filter_middleware.py index 3f7ad21583..e24e249daf 100644 --- a/tests/core/middleware/test_filter_middleware.py +++ b/tests/core/middleware/test_filter_middleware.py @@ -12,9 +12,6 @@ from web3.datastructures import ( AttributeDict, ) -from web3.eth import ( - AsyncEth, -) from web3.middleware import ( attrdict_middleware, local_filter_middleware, diff --git a/tests/core/middleware/test_formatting_middleware.py b/tests/core/middleware/test_formatting_middleware.py index 1bd5cec913..969f2bfb46 100644 --- a/tests/core/middleware/test_formatting_middleware.py +++ b/tests/core/middleware/test_formatting_middleware.py @@ -12,7 +12,9 @@ from web3.providers.base import ( BaseProvider, ) -from web3.types import RPCEndpoint +from web3.types import ( + RPCEndpoint, +) class DummyProvider(BaseProvider): diff --git a/tests/core/middleware/test_gas_price_strategy.py b/tests/core/middleware/test_gas_price_strategy.py index 54a993cb6a..8f62e4a270 100644 --- a/tests/core/middleware/test_gas_price_strategy.py +++ b/tests/core/middleware/test_gas_price_strategy.py @@ -3,7 +3,9 @@ Mock, ) -from toolz import merge +from toolz import ( + merge, +) from web3.middleware import ( gas_price_strategy_middleware, @@ -11,23 +13,22 @@ @pytest.fixture -def the_gas_price_strategy_middleware(w3): +def the_gas_price_strategy_middleware(): w3 = Mock() initialized = gas_price_strategy_middleware(w3) return initialized def test_gas_price_generated(the_gas_price_strategy_middleware): - the_gas_price_strategy_middleware._w3.eth.generate_gas_price.return_value = 5 + w3 = the_gas_price_strategy_middleware._w3 + w3.eth.generate_gas_price.return_value = 5 make_request = Mock() inner = the_gas_price_strategy_middleware._wrap_make_request(make_request) method, dict_param = "eth_sendTransaction", {"to": "0x0", "value": 1} inner(method, (dict_param,)) - the_gas_price_strategy_middleware._w3.eth.generate_gas_price.assert_called_once_with( - dict_param - ) + w3.eth.generate_gas_price.assert_called_once_with(dict_param) make_request.assert_called_once_with( method, (merge(dict_param, {"gasPrice": "0x5"}),) ) diff --git a/tests/core/middleware/test_stalecheck.py b/tests/core/middleware/test_stalecheck.py index a722cec56a..448332d266 100644 --- a/tests/core/middleware/test_stalecheck.py +++ b/tests/core/middleware/test_stalecheck.py @@ -163,7 +163,9 @@ async def async_request_middleware(allowable_delay): @pytest.mark.asyncio @min_version async def test_async_stalecheck_pass(async_request_middleware): - from unittest.mock import AsyncMock + from unittest.mock import ( + AsyncMock, + ) with patch("web3.middleware.stalecheck._is_fresh", return_value=True): make_request = AsyncMock() diff --git a/tests/core/middleware/test_transaction_signing.py b/tests/core/middleware/test_transaction_signing.py index b0297ebc41..e230556dc7 100644 --- a/tests/core/middleware/test_transaction_signing.py +++ b/tests/core/middleware/test_transaction_signing.py @@ -38,7 +38,9 @@ from web3.middleware import ( construct_sign_and_send_raw_middleware, ) -from web3.middleware.signing import gen_normalized_accounts +from web3.middleware.signing import ( + gen_normalized_accounts, +) from web3.providers import ( AsyncBaseProvider, BaseProvider, diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 4f1807ba43..892aff0bad 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -25,8 +25,8 @@ from web3.middleware import ( attrdict_middleware, buffered_gas_estimate_middleware, - gas_price_strategy_middleware, ens_name_to_address_middleware, + gas_price_strategy_middleware, validation_middleware, ) from web3.net import ( diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 245c8f35d9..126ab5ffba 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -26,11 +26,10 @@ GethTxPool, ) from web3.middleware import ( - abi_middleware, attrdict_middleware, buffered_gas_estimate_middleware, - gas_price_strategy_middleware, ens_name_to_address_middleware, + gas_price_strategy_middleware, validation_middleware, ) from web3.net import ( diff --git a/tests/core/providers/test_http_request_retry.py b/tests/core/providers/test_http_request_retry.py index ad27ccafba..22d0c2debd 100644 --- a/tests/core/providers/test_http_request_retry.py +++ b/tests/core/providers/test_http_request_retry.py @@ -1,11 +1,9 @@ import pytest from unittest.mock import ( - Mock, patch, ) import aiohttp -import pytest_asyncio from requests.exceptions import ( ConnectionError, HTTPError, @@ -13,7 +11,6 @@ TooManyRedirects, ) -import web3 from web3 import ( AsyncHTTPProvider, AsyncWeb3, @@ -27,8 +24,9 @@ ExceptionRetryConfiguration, check_if_retry_on_failure, ) -from web3.types import RPCEndpoint - +from web3.types import ( + RPCEndpoint, +) TEST_RETRY_COUNT = 3 diff --git a/tests/core/providers/test_ipc_provider.py b/tests/core/providers/test_ipc_provider.py index 6847e91f06..764f8f5833 100644 --- a/tests/core/providers/test_ipc_provider.py +++ b/tests/core/providers/test_ipc_provider.py @@ -17,7 +17,9 @@ from web3.providers.ipc import ( IPCProvider, ) -from web3.types import RPCEndpoint +from web3.types import ( + RPCEndpoint, +) @pytest.fixture diff --git a/tests/ens/conftest.py b/tests/ens/conftest.py index 919eb463b3..793195b4e4 100644 --- a/tests/ens/conftest.py +++ b/tests/ens/conftest.py @@ -362,8 +362,7 @@ def TEST_ADDRESS(address_conversion_func): @pytest_asyncio.fixture(scope="session") def async_w3(): - provider = AsyncEthereumTesterProvider() - _async_w3 = AsyncWeb3(provider, middlewares=provider.middlewares) + _async_w3 = AsyncWeb3(AsyncEthereumTesterProvider()) return _async_w3 diff --git a/web3/_utils/caching.py b/web3/_utils/caching.py index b95eae603a..8d958607e7 100644 --- a/web3/_utils/caching.py +++ b/web3/_utils/caching.py @@ -1,16 +1,16 @@ import collections -import copy import hashlib from typing import ( TYPE_CHECKING, Any, Callable, + Coroutine, List, Tuple, + TypeVar, + Union, ) -from toolz import merge - from eth_utils import ( is_boolean, is_bytes, @@ -24,13 +24,21 @@ if TYPE_CHECKING: from web3.providers import ( # noqa: F401 + AsyncBaseProvider, BaseProvider, ) - from web3.types import ( + from web3.types import ( # noqa: F401 + AsyncMakeRequestFn, + MakeRequestFn, RPCEndpoint, + RPCResponse, ) +SYNC_PROVIDER_TYPE = TypeVar("SYNC_PROVIDER_TYPE", bound="BaseProvider") +ASYNC_PROVIDER_TYPE = TypeVar("ASYNC_PROVIDER_TYPE", bound="AsyncBaseProvider") + + def generate_cache_key(value: Any) -> str: """ Generates a cache key for the *args and **kwargs @@ -66,40 +74,50 @@ def __init__( self.middleware_response_processors: List[Callable[..., Any]] = [] -def is_cacheable_request(provider: "BaseProvider", method: "RPCEndpoint") -> bool: +def is_cacheable_request( + provider: Union[ASYNC_PROVIDER_TYPE, SYNC_PROVIDER_TYPE], method: "RPCEndpoint" +) -> bool: if provider.cache_allowed_requests and method in provider.cacheable_requests: return True + return False # -- request caching decorators -- # -def handle_request_caching(func): - def wrapper(*args, **kwargs): - # args=(self, method, params) - where "self" should be BaseProvider instance - provider, method, params = args - +def handle_request_caching( + func: Callable[[SYNC_PROVIDER_TYPE, "RPCEndpoint", Any], "RPCResponse"] +) -> Callable[..., "RPCResponse"]: + def wrapper( + provider: SYNC_PROVIDER_TYPE, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": if is_cacheable_request(provider, method): request_cache = provider._request_cache - cache_key = generate_cache_key((method, params, kwargs)) + cache_key = generate_cache_key((method, params)) cache_result = request_cache.get_cache_entry(cache_key) if cache_result is not None: return cache_result else: - response = func(*args, **kwargs) + response = func(provider, method, params) request_cache.cache(cache_key, response) return response else: - return func(*args, **kwargs) + return func(provider, method, params) return wrapper -def async_handle_request_caching(func): - async def wrapper(*args, **kwargs): - # args=(self, method, params) - where "self" should be the provider instance - provider, method, params = args +# -- async -- # + +def async_handle_request_caching( + func: Callable[ + [ASYNC_PROVIDER_TYPE, "RPCEndpoint", Any], Coroutine[Any, Any, "RPCResponse"] + ], +) -> Callable[..., Coroutine[Any, Any, "RPCResponse"]]: + async def wrapper( + provider: ASYNC_PROVIDER_TYPE, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": if is_cacheable_request(provider, method): request_cache = provider._request_cache cache_key = generate_cache_key((method, params)) @@ -107,10 +125,10 @@ async def wrapper(*args, **kwargs): if cache_result is not None: return cache_result else: - response = await func(*args, **kwargs) + response = await func(provider, method, params) request_cache.cache(cache_key, response) return response else: - return await func(*args, **kwargs) + return await func(provider, method, params) return wrapper diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index 34659e954e..feee67ebd7 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -651,7 +651,7 @@ async def test_validation_middleware_chain_id_mismatch( @pytest.mark.asyncio async def test_geth_poa_middleware( - self, async_w3: "AsyncWeb3", request_mocker + self, async_w3: "AsyncWeb3", request_mocker: Type[RequestMocker] ) -> None: async_w3.middleware_onion.inject(extradata_to_poa_middleware, "poa", layer=0) extra_data = f"0x{'ff' * 33}" diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py index 0166e38eb5..f7bd59de41 100644 --- a/web3/_utils/module_testing/utils.py +++ b/web3/_utils/module_testing/utils.py @@ -1,29 +1,37 @@ +from asyncio import ( + iscoroutinefunction, +) import copy -from asyncio import iscoroutine, iscoroutinefunction from typing import ( + TYPE_CHECKING, Any, - Callable, Dict, - TypeVar, Union, + cast, ) from toolz import ( merge, ) -from web3 import ( - AsyncWeb3, - Web3, -) from web3.exceptions import ( Web3ValidationError, ) -from web3.types import ( - RPCEndpoint, -) -WEB3 = TypeVar("WEB3", Web3, AsyncWeb3) +if TYPE_CHECKING: + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3._utils.compat import ( # noqa: F401 + Self, + ) + from web3.types import ( # noqa: F401 + AsyncMakeRequestFn, + MakeRequestFn, + RPCEndpoint, + RPCResponse, + ) class RequestMocker: @@ -49,29 +57,36 @@ def test_my_w3(w3, request_mocker): def __init__( self, - w3: WEB3, - mock_results: Dict[Union[RPCEndpoint, str], Any] = None, - mock_errors: Dict[Union[RPCEndpoint, str], Dict[str, Any]] = None, + w3: Union["AsyncWeb3", "Web3"], + mock_results: Dict[Union["RPCEndpoint", str], Any] = None, + mock_errors: Dict[Union["RPCEndpoint", str], Dict[str, Any]] = None, ): self.w3 = w3 self.mock_results = mock_results or {} self.mock_errors = mock_errors or {} - self._make_request = w3.provider.make_request - - def __enter__(self): - self.w3.provider.make_request = self._mock_request_handler + self._make_request: Union[ + "AsyncMakeRequestFn", "MakeRequestFn" + ] = w3.provider.make_request + def __enter__(self) -> "Self": + setattr(self.w3.provider, "make_request", self._mock_request_handler) # reset request func cache to re-build request_func with mocked make_request self.w3.provider._request_func_cache = (None, None) return self - def __exit__(self, exc_type, exc_value, traceback): - self.w3.provider.make_request = self._make_request + # define __exit__ with typing information + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + setattr(self.w3.provider, "make_request", self._make_request) # reset request func cache to re-build request_func with original make_request self.w3.provider._request_func_cache = (None, None) - def _mock_request_handler(self, method, params): + def _mock_request_handler( + self, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": + self.w3 = cast("Web3", self.w3) + self._make_request = cast("MakeRequestFn", self._make_request) + if method not in self.mock_errors and method not in self.mock_results: return self._make_request(method, params) @@ -84,7 +99,7 @@ def _mock_request_handler(self, method, params): if method in self.mock_results: mock_return = self.mock_results[method] - if isinstance(mock_return, Callable): + if callable(mock_return): mock_return = mock_return(method, params) return merge(response_dict, {"result": mock_return}) elif method in self.mock_errors: @@ -97,20 +112,27 @@ def _mock_request_handler(self, method, params): response_dict, {"error": merge({"code": code, "message": message}, error)}, ) + else: + raise Exception("Invariant: unreachable code path") # -- async -- # - async def __aenter__(self): - self.w3.provider.make_request = self._async_mock_request_handler + async def __aenter__(self) -> "Self": + setattr(self.w3.provider, "make_request", self._async_mock_request_handler) # reset request func cache to re-build request_func with mocked make_request self.w3.provider._request_func_cache = (None, None) return self - async def __aexit__(self, exc_type, exc_value, traceback): - self.w3.provider.make_request = self._make_request + async def __aexit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + setattr(self.w3.provider, "make_request", self._make_request) # reset request func cache to re-build request_func with original make_request self.w3.provider._request_func_cache = (None, None) - async def _async_mock_request_handler(self, method, params): + async def _async_mock_request_handler( + self, method: "RPCEndpoint", params: Any + ) -> "RPCResponse": + self.w3 = cast("AsyncWeb3", self.w3) + self._make_request = cast("AsyncMakeRequestFn", self._make_request) + if method not in self.mock_errors and method not in self.mock_results: return await self._make_request(method, params) @@ -123,7 +145,7 @@ async def _async_mock_request_handler(self, method, params): if method in self.mock_results: mock_return = self.mock_results[method] - if isinstance(mock_return, Callable): + if callable(mock_return): # handle callable to make things easier since we're mocking mock_return = mock_return(method, params) elif iscoroutinefunction(mock_return): @@ -140,3 +162,5 @@ async def _async_mock_request_handler(self, method, params): response_dict, {"error": merge({"code": code, "message": message}, error)}, ) + else: + raise Exception("Invariant: unreachable code path") diff --git a/web3/datastructures.py b/web3/datastructures.py index a999da1db1..2aa3632d45 100644 --- a/web3/datastructures.py +++ b/web3/datastructures.py @@ -176,7 +176,7 @@ def add(self, element: TValue, name: Optional[TKey] = None) -> None: # handle unhashable types name.__hash__() except TypeError: - name = repr(name) + name = cast(TKey, repr(name)) if name in self._queue: if name is element: @@ -217,7 +217,7 @@ def inject( # handle unhashable types name.__hash__() except TypeError: - name = repr(name) + name = cast(TKey, repr(name)) self._queue.move_to_end(name, last=False) elif layer == len(self._queue): diff --git a/web3/main.py b/web3/main.py index 61461b319c..51a238d82f 100644 --- a/web3/main.py +++ b/web3/main.py @@ -102,6 +102,7 @@ from web3.manager import ( RequestManager as DefaultRequestManager, ) +from web3.middleware.base import MiddlewareOnion from web3.module import ( Module, ) @@ -139,8 +140,6 @@ Tracing, ) from web3.types import ( - AsyncMiddlewareOnion, - MiddlewareOnion, Wei, ) @@ -470,8 +469,8 @@ async def is_connected(self, show_traceback: bool = False) -> bool: return await self.provider.is_connected(show_traceback) @property - def middleware_onion(self) -> AsyncMiddlewareOnion: - return cast(AsyncMiddlewareOnion, self.manager.middleware_onion) + def middleware_onion(self) -> MiddlewareOnion: + return cast(MiddlewareOnion, self.manager.middleware_onion) @property def provider(self) -> AsyncBaseProvider: diff --git a/web3/manager.py b/web3/manager.py index a7c666a179..28428d7c7f 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -43,6 +43,10 @@ gas_price_strategy_middleware, validation_middleware, ) +from web3.middleware.base import ( + Middleware, + MiddlewareOnion, +) from web3.module import ( apply_result_formatters, ) @@ -51,10 +55,6 @@ PersistentConnectionProvider, ) from web3.types import ( - AsyncMiddleware, - AsyncMiddlewareOnion, - Middleware, - MiddlewareOnion, RPCEndpoint, RPCResponse, ) @@ -116,21 +116,15 @@ def apply_null_result_formatters( class RequestManager: - logger = logging.getLogger("web3.RequestManager") + logger = logging.getLogger("web3.manager.RequestManager") - middleware_onion: Union[ - MiddlewareOnion, AsyncMiddlewareOnion, NamedElementOnion[None, None] - ] + middleware_onion: Union["MiddlewareOnion", NamedElementOnion[None, None]] def __init__( self, w3: Union["AsyncWeb3", "Web3"], provider: Optional[Union["BaseProvider", "AsyncBaseProvider"]] = None, - middlewares: Optional[ - Union[ - Sequence[Tuple[Middleware, str]], Sequence[Tuple[AsyncMiddleware, str]] - ] - ] = None, + middlewares: Optional[Sequence[Tuple[Middleware, str]]] = None, ) -> None: self.w3 = w3 @@ -140,7 +134,7 @@ def __init__( self.provider = provider if middlewares is None: - middlewares = self.default_middlewares(w3) + middlewares = self.get_default_middlewares() self.middleware_onion = NamedElementOnion(middlewares) @@ -162,10 +156,9 @@ def provider(self, provider: Union["BaseProvider", "AsyncBaseProvider"]) -> None self._provider = provider @staticmethod - def default_middlewares(w3: "Web3") -> List[Tuple["Web3Middleware", str]]: + def get_default_middlewares() -> List[Tuple[Middleware, str]]: """ List the default middlewares for the request manager. - Leaving w3 unspecified will prevent the middleware from resolving names. Documentation should remain in sync with these defaults. """ return [ @@ -184,7 +177,7 @@ def _make_request( ) -> RPCResponse: provider = cast("BaseProvider", self.provider) request_func = provider.request_func( - cast("Web3", self.w3), cast(MiddlewareOnion, self.middleware_onion) + cast("Web3", self.w3), cast("MiddlewareOnion", self.middleware_onion) ) self.logger.debug(f"Making request. Method: {method}") return request_func(method, params) @@ -194,8 +187,7 @@ async def _coro_make_request( ) -> RPCResponse: provider = cast("AsyncBaseProvider", self.provider) request_func = await provider.request_func( - cast("AsyncWeb3", self.w3), - cast(AsyncMiddlewareOnion, self.middleware_onion), + cast("AsyncWeb3", self.w3), cast("MiddlewareOnion", self.middleware_onion) ) self.logger.debug(f"Making request. Method: {method}") return await request_func(method, params) @@ -323,8 +315,7 @@ async def coro_request( async def ws_send(self, method: RPCEndpoint, params: Any) -> RPCResponse: provider = cast(PersistentConnectionProvider, self._provider) request_func = await provider.request_func( - cast("AsyncWeb3", self.w3), - cast(AsyncMiddlewareOnion, self.middleware_onion), + cast("AsyncWeb3", self.w3), cast("MiddlewareOnion", self.middleware_onion) ) self.logger.debug( "Making request to open websocket connection - " diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index bca997ace3..671aad4753 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -4,19 +4,13 @@ Callable, Coroutine, Sequence, - Type, -) - -from web3.types import ( - RPCEndpoint, - RPCResponse, ) from .attrdict import ( attrdict_middleware, ) from .base import ( - Web3Middleware, + Middleware, ) from .buffered_gas_estimate import ( buffered_gas_estimate_middleware, @@ -56,16 +50,27 @@ from .validation import ( validation_middleware, ) +from ..types import ( + AsyncMakeRequestFn, + MakeRequestFn, +) + if TYPE_CHECKING: - from web3 import AsyncWeb3, Web3 + from web3 import ( + AsyncWeb3, + Web3, + ) + from web3.types import ( + RPCResponse, + ) def combine_middlewares( - middlewares: Sequence[Type[Web3Middleware]], + middlewares: Sequence[Middleware], w3: "Web3", - provider_request_fn: Callable[[RPCEndpoint, Any], Any], -) -> Callable[..., RPCResponse]: + provider_request_fn: MakeRequestFn, +) -> Callable[..., "RPCResponse"]: """ Returns a callable function which takes method and params as positional arguments and passes these args through the request processors, makes the request, and passes @@ -79,10 +84,10 @@ def combine_middlewares( async def async_combine_middlewares( - middlewares: Sequence[Type[Web3Middleware]], + middlewares: Sequence[Middleware], async_w3: "AsyncWeb3", - provider_request_fn: Callable[[RPCEndpoint, Any], Any], -) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: + provider_request_fn: AsyncMakeRequestFn, +) -> Callable[..., Coroutine[Any, Any, "RPCResponse"]]: """ Returns a callable function which takes method and params as positional arguments and passes these args through the request processors, makes the request, and passes diff --git a/web3/middleware/attrdict.py b/web3/middleware/attrdict.py index 931e2d06a1..20051a6c33 100644 --- a/web3/middleware/attrdict.py +++ b/web3/middleware/attrdict.py @@ -17,10 +17,6 @@ from web3.middleware.base import ( Web3Middleware, ) -from web3.types import ( - RPCEndpoint, - RPCResponse, -) if TYPE_CHECKING: from web3 import ( # noqa: F401 @@ -30,9 +26,13 @@ from web3.providers import ( # noqa: F401 PersistentConnectionProvider, ) + from web3.types import ( # noqa: F401 + RPCEndpoint, + RPCResponse, + ) -def _handle_async_response(response: RPCResponse) -> RPCResponse: +def _handle_async_response(response: "RPCResponse") -> "RPCResponse": if "result" in response: return assoc(response, "result", AttributeDict.recursive(response["result"])) elif "params" in response and "result" in response["params"]: diff --git a/web3/middleware/base.py b/web3/middleware/base.py index fbe12e58b6..e4c2a044cc 100644 --- a/web3/middleware/base.py +++ b/web3/middleware/base.py @@ -1,41 +1,44 @@ -import warnings -from abc import abstractmethod +from abc import ( + abstractmethod, +) from typing import ( - Sequence, TYPE_CHECKING, Any, - TypeVar, + Type, Union, ) +from web3.datastructures import ( + NamedElementOnion, +) + if TYPE_CHECKING: from web3 import ( # noqa: F401 AsyncWeb3, Web3, ) - from web3.types import ( + from web3.types import ( # noqa: F401 + AsyncMakeRequestFn, + MakeRequestFn, RPCEndpoint, RPCResponse, ) -WEB3 = TypeVar("WEB3", "AsyncWeb3", "Web3") - - class Web3Middleware: """ Base class for web3.py middleware. This class is not meant to be used directly, but instead inherited from. """ - _w3: WEB3 + _w3: Union["AsyncWeb3", "Web3"] - def __init__(self, w3: WEB3) -> None: + def __init__(self, w3: Union["AsyncWeb3", "Web3"]) -> None: self._w3 = w3 # -- sync -- # - def _wrap_make_request(self, make_request): + def _wrap_make_request(self, make_request: "MakeRequestFn") -> "MakeRequestFn": def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": method, params = self.request_processor(method, params) return self.response_processor(method, make_request(method, params)) @@ -50,7 +53,9 @@ def response_processor(self, method: "RPCEndpoint", response: "RPCResponse") -> # -- async -- # - async def _async_wrap_make_request(self, make_request): + async def _async_wrap_make_request( + self, make_request: "AsyncMakeRequestFn" + ) -> "AsyncMakeRequestFn": async def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": method, params = await self.async_request_processor(method, params) return await self.async_response_processor( @@ -82,7 +87,7 @@ def build( w3: Union["AsyncWeb3", "Web3"], *args: Any, **kwargs: Any, - ): + ) -> Web3Middleware: """ Implementation should initialize the middleware class that implements it, load it with any of the necessary properties that it needs for processing, @@ -116,3 +121,9 @@ def response_processor(self, method, response): ``` """ raise NotImplementedError("Must be implemented by subclasses") + + +# --- type definitions --- # + +Middleware = Type[Web3Middleware] +MiddlewareOnion = NamedElementOnion[str, Middleware] diff --git a/web3/middleware/buffered_gas_estimate.py b/web3/middleware/buffered_gas_estimate.py index 11a3dede40..bb9ab509e6 100644 --- a/web3/middleware/buffered_gas_estimate.py +++ b/web3/middleware/buffered_gas_estimate.py @@ -1,6 +1,7 @@ from typing import ( TYPE_CHECKING, Any, + cast, ) from eth_utils.toolz import ( @@ -39,7 +40,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: transaction = assoc( transaction, "gas", - hex(get_buffered_gas_estimate(self._w3, transaction)), + hex(get_buffered_gas_estimate(cast("Web3", self._w3), transaction)), ) params = (transaction,) return method, params @@ -51,7 +52,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A transaction = params[0] if "gas" not in transaction: gas_estimate = await async_get_buffered_gas_estimate( - self._w3, transaction + cast("AsyncWeb3", self._w3), transaction ) transaction = assoc(transaction, "gas", hex(gas_estimate)) params = (transaction,) diff --git a/web3/middleware/filter.py b/web3/middleware/filter.py index 8e30d4027d..42afa5dc60 100644 --- a/web3/middleware/filter.py +++ b/web3/middleware/filter.py @@ -46,17 +46,23 @@ Web3Middleware, ) from web3.types import ( - Coroutine, + AsyncMakeRequestFn, FilterParams, LatestBlockParam, LogReceipt, - RPCEndpoint, - RPCResponse, + MakeRequestFn, _Hash32, ) if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3.types import ( # noqa: F401 + RPCEndpoint, + RPCResponse, + ) if "WEB3_MAX_BLOCK_REQUEST" in os.environ: MAX_BLOCK_REQUEST = to_int(text=os.environ["WEB3_MAX_BLOCK_REQUEST"]) @@ -344,7 +350,7 @@ def block_hashes_in_range( async def async_iter_latest_block( - w3: "Web3", to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None + w3: "AsyncWeb3", to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None ) -> AsyncIterable[BlockNumber]: """Returns a generator that dispenses the latest block, if any new blocks have been mined since last iteration. @@ -367,9 +373,9 @@ async def async_iter_latest_block( is_bounded_range = to_block is not None and to_block != "latest" while True: - latest_block = await w3.eth.block_number # type: ignore + latest_block = await w3.eth.block_number # type ignored b/c is_bounded_range prevents unsupported comparison - if is_bounded_range and latest_block > to_block: + if is_bounded_range and latest_block > cast(int, to_block): yield None # No new blocks since last iteration. if _last is not None and _last == latest_block: @@ -380,7 +386,7 @@ async def async_iter_latest_block( async def async_iter_latest_block_ranges( - w3: "Web3", + w3: "AsyncWeb3", from_block: BlockNumber, to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, ) -> AsyncIterable[Tuple[Optional[BlockNumber], Optional[BlockNumber]]]: @@ -410,7 +416,7 @@ async def async_iter_latest_block_ranges( async def async_get_logs_multipart( - w3: "Web3", + w3: "AsyncWeb3", start_block: BlockNumber, stop_block: BlockNumber, address: Union[Address, ChecksumAddress, List[Union[Address, ChecksumAddress]]], @@ -433,7 +439,7 @@ async def async_get_logs_multipart( params_with_none_dropped = cast( FilterParams, drop_items_with_none_value(params) ) - next_logs = await w3.eth.get_logs(params_with_none_dropped) # type: ignore + next_logs = await w3.eth.get_logs(params_with_none_dropped) yield next_logs @@ -442,7 +448,7 @@ class AsyncRequestLogs: def __init__( self, - w3: "Web3", + w3: "AsyncWeb3", from_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, to_block: Optional[Union[BlockNumber, LatestBlockParam]] = None, address: Optional[ @@ -460,7 +466,7 @@ def __init__( def __await__(self) -> Generator[Any, None, "AsyncRequestLogs"]: async def closure() -> "AsyncRequestLogs": if self._from_block_arg is None or self._from_block_arg == "latest": - self.block_number = await self.w3.eth.block_number # type: ignore + self.block_number = await self.w3.eth.block_number self._from_block = BlockNumber(self.block_number + 1) elif is_string(self._from_block_arg) and is_hex(self._from_block_arg): self._from_block = BlockNumber( @@ -480,7 +486,7 @@ async def from_block(self) -> BlockNumber: @property async def to_block(self) -> BlockNumber: if self._to_block is None or self._to_block == "latest": - to_block = await self.w3.eth.block_number # type: ignore + to_block = await self.w3.eth.block_number elif is_string(self._to_block) and is_hex(self._to_block): to_block = BlockNumber(hex_to_integer(cast(HexStr, self._to_block))) else: @@ -528,12 +534,12 @@ async def get_logs(self) -> List[LogReceipt]: class AsyncRequestBlocks: - def __init__(self, w3: "Web3") -> None: + def __init__(self, w3: "AsyncWeb3") -> None: self.w3 = w3 def __await__(self) -> Generator[Any, None, "AsyncRequestBlocks"]: async def closure() -> "AsyncRequestBlocks": - self.block_number = await self.w3.eth.block_number # type: ignore + self.block_number = await self.w3.eth.block_number self.start_block = BlockNumber(self.block_number + 1) return self @@ -553,7 +559,7 @@ async def get_filter_changes(self) -> AsyncIterator[List[Hash32]]: async def async_block_hashes_in_range( - w3: "Web3", block_range: Tuple[BlockNumber, BlockNumber] + w3: "AsyncWeb3", block_range: Tuple[BlockNumber, BlockNumber] ) -> List[Union[None, Hash32]]: from_block, to_block = block_range if from_block is None or to_block is None: @@ -561,7 +567,7 @@ async def async_block_hashes_in_range( block_hashes = [] for block_number in range(from_block, to_block + 1): - w3_get_block = await w3.eth.get_block(BlockNumber(block_number)) # type: ignore + w3_get_block = await w3.eth.get_block(BlockNumber(block_number)) block_hashes.append(getattr(w3_get_block, "hash", None)) return block_hashes @@ -569,27 +575,31 @@ async def async_block_hashes_in_range( # -- middleware -- # +SyncFilter = Union[RequestLogs, RequestBlocks] +AsyncFilter = Union[AsyncRequestLogs, AsyncRequestBlocks] + class LocalFilterMiddleware(Web3Middleware): - def __init__(self, w3): - self.filters = {} + def __init__(self, w3: Union["Web3", "AsyncWeb3"]): + self.filters: Dict[str, SyncFilter] = {} + self.async_filters: Dict[str, AsyncFilter] = {} self.filter_id_counter = itertools.count() super().__init__(w3) - def _wrap_make_request(self, make_request): - def middleware(method, params): + def _wrap_make_request(self, make_request: MakeRequestFn) -> MakeRequestFn: + def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": if method in NEW_FILTER_METHODS: + _filter: SyncFilter filter_id = to_hex(next(self.filter_id_counter)) - _filter: Union[RequestLogs, RequestBlocks] - if method == RPC.eth_newFilter: _filter = RequestLogs( - self._w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + cast("Web3", self._w3), + **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) ) elif method == RPC.eth_newBlockFilter: - _filter = RequestBlocks(self._w3) + _filter = RequestBlocks(cast("Web3", self._w3)) else: raise NotImplementedError(method) @@ -621,42 +631,45 @@ def middleware(method, params): # -- async -- # - async def _async_wrap_make_request(self, make_request): - async def middleware(method, params): + async def _async_wrap_make_request( + self, + make_request: AsyncMakeRequestFn, + ) -> AsyncMakeRequestFn: + async def middleware(method: "RPCEndpoint", params: Any) -> "RPCResponse": if method in NEW_FILTER_METHODS: + _filter: AsyncFilter filter_id = to_hex(next(self.filter_id_counter)) - _filter: Union[AsyncRequestLogs, AsyncRequestBlocks] - if method == RPC.eth_newFilter: _filter = await AsyncRequestLogs( - self._w3, **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) + cast("AsyncWeb3", self._w3), + **apply_key_map(FILTER_PARAMS_KEY_MAP, params[0]) ) elif method == RPC.eth_newBlockFilter: - _filter = await AsyncRequestBlocks(self._w3) + _filter = await AsyncRequestBlocks(cast("AsyncWeb3", self._w3)) else: raise NotImplementedError(method) - self.filters[filter_id] = _filter + self.async_filters[filter_id] = _filter return {"result": filter_id} elif method in FILTER_CHANGES_METHODS: _filter_id = params[0] # Pass through to filters not created by middleware - if _filter_id not in self.filters: + if _filter_id not in self.async_filters: return await make_request(method, params) - _filter = self.filters[_filter_id] + _filter = self.async_filters[_filter_id] if method == RPC.eth_getFilterChanges: return {"result": await _filter.filter_changes.__anext__()} elif method == RPC.eth_getFilterLogs: # type ignored b/c logic prevents RequestBlocks which # doesn't implement get_logs - return {"result": await _filter.get_logs()} + return {"result": await _filter.get_logs()} # type: ignore else: raise NotImplementedError(method) else: diff --git a/web3/middleware/formatting.py b/web3/middleware/formatting.py index 4870272d9b..f3b6dcd021 100644 --- a/web3/middleware/formatting.py +++ b/web3/middleware/formatting.py @@ -31,6 +31,9 @@ AsyncWeb3, Web3, ) + from web3.middleware.base import ( # noqa: F401 + Web3Middleware, + ) from web3.providers import ( # noqa: F401 PersistentConnectionProvider, ) @@ -110,7 +113,7 @@ def build( # formatters builder option: sync_formatters_builder: Optional[SYNC_FORMATTERS_BUILDER] = None, async_formatters_builder: Optional[ASYNC_FORMATTERS_BUILDER] = None, - ): + ) -> "FormattingMiddlewareBuilder": # if not both sync and async formatters are specified, raise error if ( sync_formatters_builder is None and async_formatters_builder is not None @@ -141,7 +144,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if self.sync_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - self.sync_formatters_builder(self._w3, method), + self.sync_formatters_builder(cast("Web3", self._w3), method), ) self.request_formatters = formatters.pop("request_formatters") @@ -155,7 +158,7 @@ def response_processor(self, method: RPCEndpoint, response: "RPCResponse") -> An if self.sync_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - self.sync_formatters_builder(self._w3, method), + self.sync_formatters_builder(cast("Web3", self._w3), method), ) self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] @@ -173,7 +176,9 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A if self.async_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - await self.async_formatters_builder(self._w3, method), + await self.async_formatters_builder( + cast("AsyncWeb3", self._w3), method + ), ) self.request_formatters = formatters.pop("request_formatters") @@ -189,7 +194,9 @@ async def async_response_processor( if self.async_formatters_builder is not None: formatters = merge( FORMATTER_DEFAULTS, - await self.async_formatters_builder(self._w3, method), + await self.async_formatters_builder( + cast("AsyncWeb3", self._w3), method + ), ) self.result_formatters = formatters["result_formatters"] self.error_formatters = formatters["error_formatters"] diff --git a/web3/middleware/gas_price_strategy.py b/web3/middleware/gas_price_strategy.py index 75e581d1b2..e259b5e054 100644 --- a/web3/middleware/gas_price_strategy.py +++ b/web3/middleware/gas_price_strategy.py @@ -1,6 +1,7 @@ from typing import ( TYPE_CHECKING, Any, + cast, ) from eth_utils.toolz import ( @@ -91,7 +92,8 @@ def request_processor(self, method: RPCEndpoint, params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] generated_gas_price = self._w3.eth.generate_gas_price(transaction) - latest_block = self._w3.eth.get_block("latest") + w3 = cast("Web3", self._w3) + latest_block = w3.eth.get_block("latest") transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) @@ -104,8 +106,9 @@ def request_processor(self, method: RPCEndpoint, params: Any) -> Any: async def async_request_processor(self, method: RPCEndpoint, params: Any) -> Any: if method == "eth_sendTransaction": transaction = params[0] - generated_gas_price = self._w3.eth.generate_gas_price(transaction) - latest_block = await self._w3.eth.get_block("latest") + w3 = cast("AsyncWeb3", self._w3) + generated_gas_price = w3.eth.generate_gas_price(transaction) + latest_block = await w3.eth.get_block("latest") transaction = validate_transaction_params( transaction, latest_block, generated_gas_price ) diff --git a/web3/middleware/names.py b/web3/middleware/names.py index f9e49998c3..486f6fa9e2 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -4,6 +4,7 @@ Dict, Sequence, Union, + cast, ) from toolz import ( @@ -106,7 +107,7 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: abi_ens_resolver(self._w3), ] self._formatting_middleware = FormattingMiddlewareBuilder.build( - request_formatters=abi_request_formatters(normalizers, RPC_ABIS) # type: ignore # noqa: E501 + request_formatters=abi_request_formatters(normalizers, RPC_ABIS) ) return self._formatting_middleware(self._w3).request_processor(method, params) @@ -121,7 +122,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A # eth_subscribe optional logs params are unique. # Handle them separately here. (formatted_dict,) = await async_apply_ens_to_address_conversion( - self._w3, + cast("AsyncWeb3", self._w3), (params[1],), { "address": "address", @@ -132,7 +133,7 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A else: params = await async_apply_ens_to_address_conversion( - self._w3, + cast("AsyncWeb3", self._w3), params, abi_types_for_method, ) diff --git a/web3/middleware/signing.py b/web3/middleware/signing.py index 9cd7c223c0..2d35932e2b 100644 --- a/web3/middleware/signing.py +++ b/web3/middleware/signing.py @@ -10,10 +10,9 @@ Tuple, TypeVar, Union, + cast, ) -from toolz import curry - from eth_account import ( Account, ) @@ -37,6 +36,9 @@ from eth_utils.toolz import ( compose, ) +from toolz import ( + curry, +) from web3._utils.async_transactions import ( async_fill_nonce, @@ -54,7 +56,6 @@ fill_transaction_defaults, ) from web3.middleware.base import ( - Web3Middleware, Web3MiddlewareBuilder, ) from web3.types import ( @@ -148,7 +149,10 @@ class SignAndSendRawMiddlewareBuilder(Web3MiddlewareBuilder): @staticmethod @curry - def build(private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]], w3): + def build( + private_key_or_account: Union[_PrivateKey, Collection[_PrivateKey]], + w3: Union["Web3", "AsyncWeb3"], + ) -> "SignAndSendRawMiddlewareBuilder": middleware = SignAndSendRawMiddlewareBuilder(w3) middleware._accounts = gen_normalized_accounts(private_key_or_account) return middleware @@ -157,11 +161,12 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method != "eth_sendTransaction": return method, params else: + w3 = cast("Web3", self._w3) if self.format_and_fill_tx is None: self.format_and_fill_tx = compose( format_transaction, - fill_transaction_defaults(self._w3), - fill_nonce(self._w3), + fill_transaction_defaults(w3), + fill_nonce(w3), ) filled_transaction = self.format_and_fill_tx(params[0]) @@ -187,15 +192,13 @@ async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> A return method, params else: + w3 = cast("AsyncWeb3", self._w3) + formatted_transaction = format_transaction(params[0]) filled_transaction = await async_fill_transaction_defaults( - self._w3, - formatted_transaction, - ) - filled_transaction = await async_fill_nonce( - self._w3, - filled_transaction, + w3, formatted_transaction ) + filled_transaction = await async_fill_nonce(w3, filled_transaction) tx_from = filled_transaction.get("from", None) if tx_from is None or ( diff --git a/web3/middleware/stalecheck.py b/web3/middleware/stalecheck.py index 4f8b98391d..e5eaedefa1 100644 --- a/web3/middleware/stalecheck.py +++ b/web3/middleware/stalecheck.py @@ -6,9 +6,13 @@ Collection, Dict, Optional, + Union, + cast, ) -from toolz import curry +from toolz import ( + curry, +) from web3.exceptions import ( StaleBlockchain, @@ -46,7 +50,7 @@ class StaleCheckMiddlewareBuilder(Web3MiddlewareBuilder): @curry def build( allowable_delay: int, - w3, + w3: Union["Web3", "AsyncWeb3"], skip_stalecheck_for_methods: Collection[str] = SKIP_STALECHECK_FOR_METHODS, ) -> Web3Middleware: if allowable_delay <= 0: @@ -62,7 +66,9 @@ def build( def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method not in self.skip_stalecheck_for_methods: if not _is_fresh(self.cache["latest"], self.allowable_delay): - latest = self._w3.eth.get_block("latest") + w3 = cast("Web3", self._w3) + latest = w3.eth.get_block("latest") + if _is_fresh(latest, self.allowable_delay): self.cache["latest"] = latest else: @@ -75,7 +81,9 @@ def request_processor(self, method: "RPCEndpoint", params: Any) -> Any: async def async_request_processor(self, method: "RPCEndpoint", params: Any) -> Any: if method not in self.skip_stalecheck_for_methods: if not _is_fresh(self.cache["latest"], self.allowable_delay): - latest = await self._w3.eth.get_block("latest") + w3 = cast("AsyncWeb3", self._w3) + latest = await w3.eth.get_block("latest") + if _is_fresh(latest, self.allowable_delay): self.cache["latest"] = latest else: diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 4a10b0671c..19a3e6a37a 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -26,11 +26,10 @@ async_combine_middlewares, ) from web3.middleware.base import ( - Web3Middleware, + Middleware, + MiddlewareOnion, ) from web3.types import ( - AsyncMiddlewareOnion, - MiddlewareOnion, RPCEndpoint, RPCResponse, ) @@ -63,9 +62,8 @@ class AsyncBaseProvider: - _middlewares: Tuple[Web3Middleware, ...] = () _request_func_cache: Tuple[ - Tuple[Web3Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] + Tuple[Middleware, ...], Callable[..., Coroutine[Any, Any, RPCResponse]] ] = (None, None) is_async = True @@ -81,27 +79,18 @@ class AsyncBaseProvider: def __init__(self) -> None: self._request_cache = SimpleCache(1000) - @property - def middlewares(self) -> Tuple[Web3Middleware, ...]: - return self._middlewares - - @middlewares.setter - def middlewares(self, values: MiddlewareOnion) -> None: - # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] - self._middlewares = tuple(values) # type: ignore - async def request_func( - self, async_w3: "AsyncWeb3", outer_middlewares: AsyncMiddlewareOnion + self, async_w3: "AsyncWeb3", middlewares: MiddlewareOnion ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - all_middlewares: Tuple[Web3Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 + middlewares: Tuple[Middleware, ...] = tuple(middlewares) # type: ignore cache_key = self._request_func_cache[0] - if cache_key != all_middlewares: + if cache_key != middlewares: # type: ignore self._request_func_cache = ( - all_middlewares, + middlewares, await async_combine_middlewares( - middlewares=all_middlewares, + middlewares=middlewares, # type: ignore async_w3=async_w3, provider_request_fn=self.make_request, ), diff --git a/web3/providers/base.py b/web3/providers/base.py index a0e59d8eca..a91e9d8708 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -21,12 +21,13 @@ ProviderConnectionError, ) from web3.middleware import ( - Web3Middleware, combine_middlewares, ) -from web3.types import ( +from web3.middleware.base import ( Middleware, MiddlewareOnion, +) +from web3.types import ( RPCEndpoint, RPCResponse, ) @@ -56,10 +57,11 @@ class BaseProvider: - _middlewares: Tuple[Middleware, ...] = () - _request_func_cache: Tuple[ - Tuple[Web3Middleware, ...], Callable[..., RPCResponse] - ] = (None, None) + # a tuple of (middlewares, request_func) + _request_func_cache: Tuple[Tuple[Middleware, ...], Callable[..., RPCResponse]] = ( + None, + None, + ) is_async = False has_persistent_connection = False @@ -74,34 +76,25 @@ class BaseProvider: def __init__(self) -> None: self._request_cache = SimpleCache(1000) - @property - def middlewares(self) -> Tuple[Web3Middleware, ...]: - return self._middlewares - - @middlewares.setter - def middlewares(self, values: MiddlewareOnion) -> None: - # tuple(values) converts to MiddlewareOnion -> Tuple[Middleware, ...] - self._middlewares = tuple(values) # type: ignore - def request_func( - self, w3: "Web3", outer_middlewares: MiddlewareOnion + self, w3: "Web3", middlewares: MiddlewareOnion ) -> Callable[..., RPCResponse]: """ @param w3 is the web3 instance - @param outer_middlewares is an iterable of middlewares, + @param middlewares is an iterable of middlewares, ordered by first to execute @returns a function that calls all the middleware and eventually self.make_request() """ # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - all_middlewares: Tuple[Web3Middleware] = tuple(outer_middlewares) + tuple(self.middlewares) # type: ignore # noqa: E501 + middlewares: Tuple[Middleware] = tuple(middlewares) # type: ignore cache_key = self._request_func_cache[0] - if cache_key != all_middlewares: + if cache_key != middlewares: # type: ignore self._request_func_cache = ( - all_middlewares, + middlewares, combine_middlewares( - middlewares=all_middlewares, + middlewares=middlewares, # type: ignore w3=w3, provider_request_fn=self.make_request, ), diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index 217e319efa..45b95718c8 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -2,6 +2,7 @@ TYPE_CHECKING, Any, Callable, + Coroutine, Dict, Optional, Union, @@ -20,9 +21,6 @@ from web3._utils.compat import ( Literal, ) -from web3.middleware.attrdict import ( - attrdict_middleware, -) from web3.providers import ( BaseProvider, ) @@ -35,6 +33,10 @@ RPCResponse, ) +from ...middleware import ( + async_combine_middlewares, + combine_middlewares, +) from .middleware import ( default_transaction_fields_middleware, ethereum_tester_middleware, @@ -44,9 +46,19 @@ from eth_tester import EthereumTester # noqa: F401 from eth_tester.backends.base import BaseChainBackend # noqa: F401 + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) + from web3.middleware.base import ( # noqa: F401 + Middleware, + MiddlewareOnion, + Web3Middleware, + ) + class AsyncEthereumTesterProvider(AsyncBaseProvider): - middlewares = ( + _middlewares = ( default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -66,6 +78,26 @@ def __init__(self) -> None: self.ethereum_tester = EthereumTester() self.api_endpoints = API_ENDPOINTS + async def request_func( + self, async_w3: "AsyncWeb3", middlewares: "MiddlewareOnion" + ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: + # override the request_func to add the ethereum_tester_middleware + + # type ignored bc tuple(MiddlewareOnion) converts to tuple of middlewares + middlewares = tuple(middlewares) + tuple(self._middlewares) # type: ignore + + cache_key = self._request_func_cache[0] + if cache_key != middlewares: # type: ignore + self._request_func_cache = ( + middlewares, + await async_combine_middlewares( + middlewares=middlewares, # type: ignore + async_w3=async_w3, + provider_request_fn=self.make_request, + ), + ) + return self._request_func_cache[-1] + async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: return _make_request(method, params, self.api_endpoints, self.ethereum_tester) @@ -74,7 +106,7 @@ async def is_connected(self, show_traceback: bool = False) -> Literal[True]: class EthereumTesterProvider(BaseProvider): - middlewares = ( + _middlewares = ( default_transaction_fields_middleware, ethereum_tester_middleware, ) @@ -121,6 +153,26 @@ def __init__( else: self.api_endpoints = api_endpoints + def request_func( + self, w3: "Web3", middlewares: "MiddlewareOnion" + ) -> Callable[..., RPCResponse]: + # override the request_func to add the ethereum_tester_middleware + + # type ignored bc tuple(MiddlewareOnion) converts to tuple of middlewares + middlewares = tuple(middlewares) + tuple(self._middlewares) # type: ignore + + cache_key = self._request_func_cache[0] + if cache_key != middlewares: # type: ignore + self._request_func_cache = ( + middlewares, + combine_middlewares( + middlewares=middlewares, # type: ignore + w3=w3, + provider_request_fn=self.make_request, + ), + ) + return self._request_func_cache[-1] + def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: return _make_request(method, params, self.api_endpoints, self.ethereum_tester) diff --git a/web3/providers/rpc/async_rpc.py b/web3/providers/rpc/async_rpc.py index e2279c72b1..1fcd6065ef 100644 --- a/web3/providers/rpc/async_rpc.py +++ b/web3/providers/rpc/async_rpc.py @@ -28,7 +28,6 @@ get_default_http_endpoint, ) from web3.types import ( - AsyncMiddleware, RPCEndpoint, RPCResponse, ) @@ -36,9 +35,6 @@ from ..._utils.caching import ( async_handle_request_caching, ) -from ...datastructures import ( - NamedElementOnion, -) from ..async_base import ( AsyncJSONBaseProvider, ) @@ -52,8 +48,6 @@ class AsyncHTTPProvider(AsyncJSONBaseProvider): logger = logging.getLogger("web3.providers.AsyncHTTPProvider") endpoint_uri = None _request_kwargs = None - # type ignored b/c conflict with _middlewares attr on AsyncBaseProvider - _middlewares: Tuple[AsyncMiddleware, ...] = NamedElementOnion([]) # type: ignore def __init__( self, @@ -119,6 +113,7 @@ async def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes continue else: raise + return None else: return await async_make_post_request( self.endpoint_uri, request_data, **self.get_request_kwargs() diff --git a/web3/providers/rpc/rpc.py b/web3/providers/rpc/rpc.py index 92238c72f0..c2c9139de9 100644 --- a/web3/providers/rpc/rpc.py +++ b/web3/providers/rpc/rpc.py @@ -1,6 +1,7 @@ import logging import time from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, @@ -24,11 +25,7 @@ get_default_http_endpoint, make_post_request, ) -from web3.datastructures import ( - NamedElementOnion, -) from web3.types import ( - Middleware, RPCEndpoint, RPCResponse, ) @@ -44,6 +41,11 @@ check_if_retry_on_failure, ) +if TYPE_CHECKING: + from web3.middleware.base import ( # noqa: F401 + Middleware, + ) + class HTTPProvider(JSONBaseProvider): logger = logging.getLogger("web3.providers.HTTPProvider") @@ -51,8 +53,6 @@ class HTTPProvider(JSONBaseProvider): _request_args = None _request_kwargs = None - # type ignored b/c conflict with _middlewares attr on BaseProvider - _middlewares: Tuple[Middleware, ...] = NamedElementOnion([]) # type: ignore # noqa: E501 exception_retry_configuration: Optional[ExceptionRetryConfiguration] = None @@ -110,12 +110,13 @@ def _make_request(self, method: RPCEndpoint, request_data: bytes) -> bytes: return make_post_request( self.endpoint_uri, request_data, **self.get_request_kwargs() ) - except tuple(self.exception_retry_configuration.errors): + except tuple(self.exception_retry_configuration.errors) as e: if i < self.exception_retry_configuration.retries - 1: time.sleep(self.exception_retry_configuration.backoff_factor) continue else: - raise + raise e + return None else: return make_post_request( self.endpoint_uri, request_data, **self.get_request_kwargs() diff --git a/web3/providers/rpc/utils.py b/web3/providers/rpc/utils.py index 4236b01523..e20fbc44f8 100644 --- a/web3/providers/rpc/utils.py +++ b/web3/providers/rpc/utils.py @@ -1,5 +1,4 @@ from typing import ( - Optional, Sequence, Type, ) diff --git a/web3/providers/websocket/websocket_v2.py b/web3/providers/websocket/websocket_v2.py index bbee3597cd..3189c6998a 100644 --- a/web3/providers/websocket/websocket_v2.py +++ b/web3/providers/websocket/websocket_v2.py @@ -40,8 +40,6 @@ RPCId, RPCResponse, ) -from web3.utils import SimpleCache - DEFAULT_PING_INTERVAL = 30 # 30 seconds DEFAULT_PING_TIMEOUT = 300 # 5 minutes diff --git a/web3/tools/benchmark/main.py b/web3/tools/benchmark/main.py index 8c40e12cee..bc96416970 100644 --- a/web3/tools/benchmark/main.py +++ b/web3/tools/benchmark/main.py @@ -24,8 +24,6 @@ Web3, ) from web3.middleware import ( - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, buffered_gas_estimate_middleware, gas_price_strategy_middleware, ) @@ -74,10 +72,7 @@ async def build_async_w3_http(endpoint_uri: str) -> AsyncWeb3: await wait_for_aiohttp(endpoint_uri) _w3 = AsyncWeb3( AsyncHTTPProvider(endpoint_uri), - middlewares=[ - async_gas_price_strategy_middleware, - async_buffered_gas_estimate_middleware, - ], + middlewares=[gas_price_strategy_middleware, buffered_gas_estimate_middleware], ) return _w3 diff --git a/web3/types.py b/web3/types.py index 8dfc1fbdf3..602364edd7 100644 --- a/web3/types.py +++ b/web3/types.py @@ -33,9 +33,6 @@ FallbackFn, ReceiveFn, ) -from web3.datastructures import ( - NamedElementOnion, -) if TYPE_CHECKING: from web3.contract.async_contract import AsyncContractFunction # noqa: F401 @@ -316,15 +313,8 @@ class CreateAccessListResponse(TypedDict): gasUsed: int -Middleware = Callable[[Callable[[RPCEndpoint, Any], RPCResponse], "Web3"], Any] -AsyncMiddlewareCoroutine = Callable[ - [RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse] -] -AsyncMiddleware = Callable[ - [Callable[[RPCEndpoint, Any], RPCResponse], "AsyncWeb3"], Any -] -MiddlewareOnion = NamedElementOnion[str, Middleware] -AsyncMiddlewareOnion = NamedElementOnion[str, AsyncMiddleware] +MakeRequestFn = Callable[[RPCEndpoint, Any], RPCResponse] +AsyncMakeRequestFn = Callable[[RPCEndpoint, Any], Coroutine[Any, Any, RPCResponse]] class FormattersDict(TypedDict, total=False): From 81decfb935f1189c1464b7a84b3665faba07892c Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 14 Dec 2023 17:46:44 -0700 Subject: [PATCH 15/24] Fix inconsistency with eth-tester integration test --- tests/integration/test_ethereum_tester.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/integration/test_ethereum_tester.py b/tests/integration/test_ethereum_tester.py index cc5d8a4fbc..dfcc3fc97c 100644 --- a/tests/integration/test_ethereum_tester.py +++ b/tests/integration/test_ethereum_tester.py @@ -266,14 +266,6 @@ def func_wrapper(self, eth_tester, *args, **kwargs): class TestEthereumTesterEthModule(EthModuleTest): - test_eth_max_priority_fee_with_fee_history_calculation = not_implemented( - EthModuleTest.test_eth_max_priority_fee_with_fee_history_calculation, - MethodUnavailable, - ) - test_eth_max_priority_fee_with_fee_history_calculation_error_dict = not_implemented( - EthModuleTest.test_eth_max_priority_fee_with_fee_history_calculation_error_dict, - ValueError, - ) test_eth_sign = not_implemented(EthModuleTest.test_eth_sign, MethodUnavailable) test_eth_sign_ens_names = not_implemented( EthModuleTest.test_eth_sign_ens_names, MethodUnavailable From 5f2f1fb64dc9b129396b9bce75a626885edb7b17 Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 15 Dec 2023 12:04:33 -0700 Subject: [PATCH 16/24] Remove unnecessary imports hidden by # noqa: F401 --- ens/ens.py | 1 - ens/utils.py | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/ens/ens.py b/ens/ens.py index d8e334678a..fcc0ecdb4f 100644 --- a/ens/ens.py +++ b/ens/ens.py @@ -78,7 +78,6 @@ BaseProvider, ) from web3.types import ( # noqa: F401 - MakeRequestFn, TxParams, ) diff --git a/ens/utils.py b/ens/utils.py index 9cff648e11..8a17fe8af7 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -59,7 +59,7 @@ AsyncWeb3, Web3 as _Web3, ) - from web3.middleware.base import ( # noqa: F401 + from web3.middleware.base import ( Middleware, ) from web3.providers import ( # noqa: F401 @@ -68,7 +68,6 @@ ) from web3.types import ( # noqa: F401 ABIFunction, - RPCEndpoint, ) From fd13864a426d52e4b79061880557b9241ad41a03 Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 15 Dec 2023 12:39:01 -0700 Subject: [PATCH 17/24] use request mocker + no need to format logsBloom in eth-tester middleware - Use request_mocker where we can - result_processors take care of the formatting... this was causing the ``None`` value for logsBloom to raise when mocked --- tests/core/eth-module/test_block_api.py | 20 ++++++++------------ web3/providers/eth_tester/middleware.py | 7 ++----- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/tests/core/eth-module/test_block_api.py b/tests/core/eth-module/test_block_api.py index 3b64161fac..2e4edf2808 100644 --- a/tests/core/eth-module/test_block_api.py +++ b/tests/core/eth-module/test_block_api.py @@ -1,7 +1,4 @@ import pytest -from unittest.mock import ( - Mock, -) from eth_utils import ( to_checksum_address, @@ -22,7 +19,7 @@ def test_uses_default_block(w3, extra_accounts, wait_for_transaction): assert w3.eth.default_block == w3.eth.block_number -def test_get_block_formatters_with_null_values(w3, mocker): +def test_get_block_formatters_with_null_values(w3, request_mocker): null_values_block = { "baseFeePerGas": None, "extraData": None, @@ -47,13 +44,12 @@ def test_get_block_formatters_with_null_values(w3, mocker): "withdrawalsRoot": None, "withdrawals": [], } - mocker.patch("web3.eth.eth.Eth.get_block", return_value=null_values_block) - - received_block = w3.eth.get_block("pending") + with request_mocker(w3, mock_results={"eth_getBlockByNumber": null_values_block}): + received_block = w3.eth.get_block("pending") assert received_block == null_values_block -def test_get_block_formatters_with_pre_formatted_values(w3): +def test_get_block_formatters_with_pre_formatted_values(w3, request_mocker): unformatted_values_block = { "baseFeePerGas": "0x3b9aca00", "extraData": "0x", @@ -107,10 +103,10 @@ def test_get_block_formatters_with_pre_formatted_values(w3): ], } - w3.manager._make_request = Mock() - w3.manager._make_request.return_value = {"result": unformatted_values_block} - - received_block = w3.eth.get_block("pending") + with request_mocker( + w3, mock_results={"eth_getBlockByNumber": unformatted_values_block} + ): + received_block = w3.eth.get_block("pending") assert received_block == { "baseFeePerGas": int(unformatted_values_block["baseFeePerGas"], 16), diff --git a/web3/providers/eth_tester/middleware.py b/web3/providers/eth_tester/middleware.py index 453a62d081..bb8c8c3576 100644 --- a/web3/providers/eth_tester/middleware.py +++ b/web3/providers/eth_tester/middleware.py @@ -189,7 +189,6 @@ def is_hexstr(value: Any) -> bool: block_result_remapper = apply_key_map(BLOCK_RESULT_KEY_MAPPING) BLOCK_RESULT_FORMATTERS = { - "logsBloom": integer_to_hex, "withdrawals": apply_list_to_array_formatter( apply_key_map({"validator_index": "validatorIndex"}), ), @@ -270,12 +269,10 @@ def is_hexstr(value: Any) -> bool: } result_formatters: Optional[Dict[RPCEndpoint, Callable[..., Any]]] = { RPCEndpoint("eth_getBlockByHash"): apply_formatter_if( - is_dict, - compose(block_result_remapper, block_result_formatter), + is_dict, compose(block_result_remapper, block_result_formatter) ), RPCEndpoint("eth_getBlockByNumber"): apply_formatter_if( - is_dict, - compose(block_result_remapper, block_result_formatter), + is_dict, compose(block_result_remapper, block_result_formatter) ), RPCEndpoint("eth_getBlockTransactionCountByHash"): apply_formatter_if( is_dict, From 45d61cb48023ed62dc68e41da2149718f400c92a Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 15 Dec 2023 12:53:55 -0700 Subject: [PATCH 18/24] Rename all references of geth_poa to extradata_to_poa - The middleware changed names in the refactor, but this commit puts the finishing touches and closes #899 --- docs/examples.rst | 4 +-- docs/middleware.rst | 27 ++++++++++--------- docs/providers.rst | 1 - tests/core/eth-module/test_poa.py | 2 +- web3/_utils/module_testing/eth_module.py | 2 +- .../persistent_connection_provider.py | 2 +- web3/middleware/proof_of_authority.py | 18 ++++++++----- web3/types.py | 2 +- 8 files changed, 32 insertions(+), 26 deletions(-) diff --git a/docs/examples.rst b/docs/examples.rst index 9260d81b8d..97099f3925 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -688,8 +688,8 @@ Inject the middleware into the middleware onion .. code-block:: python - from web3.middleware import geth_poa_middleware - w3.middleware_onion.inject(geth_poa_middleware, layer=0) + from web3.middleware import extradata_to_poa_middleware + w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) Just remember that you have to sign all transactions locally, as infura does not handle any keys from your wallet ( refer to `this`_ ) diff --git a/docs/middleware.rst b/docs/middleware.rst index baf3869982..1ff48d1b59 100644 --- a/docs/middleware.rst +++ b/docs/middleware.rst @@ -446,16 +446,15 @@ Time-based Cache Middleware Proof of Authority ~~~~~~~~~~~~~~~~~~ -.. py:method:: web3.middleware.geth_poa_middleware - web3.middleware.async_geth_poa_middleware +.. py:class:: web3.middleware.extradata_to_poa_middleware .. note:: It's important to inject the middleware at the 0th layer of the middleware onion: - ``w3.middleware_onion.inject(geth_poa_middleware, layer=0)`` + ``w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0)`` -The ``geth_poa_middleware`` is required to connect to ``geth --dev`` or the Goerli -public network. It may also be needed for other EVM compatible blockchains like Polygon -or BNB Chain (Binance Smart Chain). +The ``extradata_to_poa_middleware`` is required to connect to ``geth --dev`` and may +also be needed for other EVM compatible blockchains like Polygon or +BNB Chain (Binance Smart Chain). If the middleware is not injected at the 0th layer of the middleware onion, you may get errors like the example below when interacting with your EVM node. @@ -468,7 +467,8 @@ errors like the example below when interacting with your EVM node. for more details. The full extraData is: HexBytes('...') -The easiest way to connect to a default ``geth --dev`` instance which loads the middleware is: +The easiest way to connect to a default ``geth --dev`` instance which loads the +middleware is: .. code-block:: python @@ -489,25 +489,26 @@ unique IPC location and loads the middleware: # connect to the IPC location started with 'geth --dev --datadir ~/mynode' >>> w3 = Web3(IPCProvider('~/mynode/geth.ipc')) - >>> from web3.middleware import geth_poa_middleware + >>> from web3.middleware import extradata_to_poa_middleware # inject the poa compatibility middleware to the innermost layer (0th layer) - >>> w3.middleware_onion.inject(geth_poa_middleware, layer=0) + >>> w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) # confirm that the connection succeeded >>> w3.client_version 'Geth/v1.7.3-stable-4bb3c89d/linux-amd64/go1.9' -Why is ``geth_poa_middleware`` necessary? -''''''''''''''''''''''''''''''''''''''''' +Why is ``extradata_to_poa_middleware`` necessary? +''''''''''''''''''''''''''''''''''''''''''''''''' There is no strong community consensus on a single Proof-of-Authority (PoA) standard yet. Some nodes have successful experiments running though. One is go-ethereum (geth), which uses a prototype PoA for its development mode and the Goerli test network. Unfortunately, it does deviate from the yellow paper specification, which constrains the -``extraData`` field in each block to a maximum of 32-bytes. Geth's PoA uses more than -32 bytes, so this middleware modifies the block data a bit before returning it. +``extraData`` field in each block to a maximum of 32-bytes. Geth is one such example +where PoA uses more than 32 bytes, so this middleware modifies the block data a bit +before returning it. .. _local-filter: diff --git a/docs/providers.rst b/docs/providers.rst index a87a68fca0..6f983ab769 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -467,7 +467,6 @@ AsyncHTTPProvider - :meth:`Attribute Dict Middleware ` - :meth:`Buffered Gas Estimate Middleware ` - :meth:`Gas Price Strategy Middleware ` - - :meth:`Geth POA Middleware ` - :meth:`Local Filter Middleware ` - :meth:`Simple Cache Middleware ` - :meth:`Stalecheck Middleware ` diff --git a/tests/core/eth-module/test_poa.py b/tests/core/eth-module/test_poa.py index 6e0c901458..6dfc678fc5 100644 --- a/tests/core/eth-module/test_poa.py +++ b/tests/core/eth-module/test_poa.py @@ -26,7 +26,7 @@ def test_full_extra_data(w3, request_mocker): assert block.extraData == b"\xff" * 32 -def test_geth_proof_of_authority(w3, request_mocker): +def test_extradata_to_poa_middleware(w3, request_mocker): w3.middleware_onion.inject(extradata_to_poa_middleware, layer=0) with request_mocker( diff --git a/web3/_utils/module_testing/eth_module.py b/web3/_utils/module_testing/eth_module.py index feee67ebd7..8fb8f89b31 100644 --- a/web3/_utils/module_testing/eth_module.py +++ b/web3/_utils/module_testing/eth_module.py @@ -650,7 +650,7 @@ async def test_validation_middleware_chain_id_mismatch( await async_w3.eth.send_transaction(txn_params) @pytest.mark.asyncio - async def test_geth_poa_middleware( + async def test_extradata_to_poa_middleware( self, async_w3: "AsyncWeb3", request_mocker: Type[RequestMocker] ) -> None: async_w3.middleware_onion.inject(extradata_to_poa_middleware, "poa", layer=0) diff --git a/web3/_utils/module_testing/persistent_connection_provider.py b/web3/_utils/module_testing/persistent_connection_provider.py index 7cf314122c..4c6ca65e86 100644 --- a/web3/_utils/module_testing/persistent_connection_provider.py +++ b/web3/_utils/module_testing/persistent_connection_provider.py @@ -316,7 +316,7 @@ async def _mocked_recv_coro() -> bytes: async_w3.provider._ws.__setattr__("recv", actual_recv_fxn) @pytest.mark.asyncio - async def test_async_geth_poa_middleware_on_eth_subscription( + async def test_async_extradata_to_poa_middleware_on_eth_subscription( self, async_w3: "_PersistentConnectionWeb3", ) -> None: diff --git a/web3/middleware/proof_of_authority.py b/web3/middleware/proof_of_authority.py index bfb4623591..0bcf1ee42d 100644 --- a/web3/middleware/proof_of_authority.py +++ b/web3/middleware/proof_of_authority.py @@ -34,29 +34,35 @@ is_not_null = complement(is_null) -remap_geth_poa_fields = apply_key_map( +remap_extradata_to_poa_fields = apply_key_map( { "extraData": "proofOfAuthorityData", } ) -pythonic_geth_poa = apply_formatters_to_dict( +pythonic_extradata_to_poa = apply_formatters_to_dict( { "proofOfAuthorityData": HexBytes, } ) -geth_poa_cleanup = compose(pythonic_geth_poa, remap_geth_poa_fields) +extradata_to_poa_cleanup = compose( + pythonic_extradata_to_poa, remap_extradata_to_poa_fields +) extradata_to_poa_middleware = FormattingMiddlewareBuilder.build( result_formatters={ - RPC.eth_getBlockByHash: apply_formatter_if(is_not_null, geth_poa_cleanup), - RPC.eth_getBlockByNumber: apply_formatter_if(is_not_null, geth_poa_cleanup), + RPC.eth_getBlockByHash: apply_formatter_if( + is_not_null, extradata_to_poa_cleanup + ), + RPC.eth_getBlockByNumber: apply_formatter_if( + is_not_null, extradata_to_poa_cleanup + ), RPC.eth_subscribe: apply_formatter_if( is_not_null, # original call to eth_subscribe returns a string, needs a dict check - apply_formatter_if(is_dict, geth_poa_cleanup), + apply_formatter_if(is_dict, extradata_to_poa_cleanup), ), }, ) diff --git a/web3/types.py b/web3/types.py index 602364edd7..413d0f0af3 100644 --- a/web3/types.py +++ b/web3/types.py @@ -219,7 +219,7 @@ class BlockData(TypedDict, total=False): withdrawals: Sequence[WithdrawalData] withdrawalsRoot: HexBytes - # geth_poa_middleware replaces extraData w/ proofOfAuthorityData + # extradata_to_poa_middleware replaces extraData w/ proofOfAuthorityData proofOfAuthorityData: HexBytes From f6a3ae9abb5ffe3d74c30273801e3c8f641e35e3 Mon Sep 17 00:00:00 2001 From: fselmo Date: Fri, 15 Dec 2023 13:04:39 -0700 Subject: [PATCH 19/24] Minor cleanups from refactor PR --- .../gas-strategies/test_time_based_gas_price_strategy.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py index c94e4c1fdb..1b46f42039 100644 --- a/tests/core/gas-strategies/test_time_based_gas_price_strategy.py +++ b/tests/core/gas-strategies/test_time_based_gas_price_strategy.py @@ -181,14 +181,6 @@ def _get_gas_price(method, params): def test_time_based_gas_price_strategy_without_transactions(request_mocker): - # fixture_middleware = construct_result_generator_middleware( - # { - # "eth_getBlockByHash": _get_initial_block, - # "eth_getBlockByNumber": _get_initial_block, - # "eth_gasPrice": _get_gas_price, - # } - # ) - w3 = Web3(provider=BaseProvider()) time_based_gas_price_strategy = construct_time_based_gas_price_strategy( From 57d89211d75a5432642efddd9bc94f844d39960b Mon Sep 17 00:00:00 2001 From: fselmo Date: Mon, 18 Dec 2023 12:07:17 -0700 Subject: [PATCH 20/24] 'name_to_address' -> 'ens_name_to_address' for consistency. --- ens/utils.py | 6 +++--- tests/core/manager/test_default_middlewares.py | 2 +- tests/core/providers/test_async_http_provider.py | 2 +- tests/core/providers/test_http_provider.py | 4 +++- web3/manager.py | 2 +- 5 files changed, 9 insertions(+), 7 deletions(-) diff --git a/ens/utils.py b/ens/utils.py index 8a17fe8af7..314beac68c 100644 --- a/ens/utils.py +++ b/ens/utils.py @@ -103,8 +103,8 @@ def customize_web3(w3: "_Web3") -> "_Web3": make_stalecheck_middleware, ) - if w3.middleware_onion.get("name_to_address"): - w3.middleware_onion.remove("name_to_address") + if w3.middleware_onion.get("ens_name_to_address"): + w3.middleware_onion.remove("ens_name_to_address") if not w3.middleware_onion.get("stalecheck"): stalecheck_middleware = make_stalecheck_middleware( @@ -313,7 +313,7 @@ def init_async_web3( middlewares = list(middlewares) for i, (middleware, name) in enumerate(middlewares): - if name == "name_to_address": + if name == "ens_name_to_address": middlewares.pop(i) if "stalecheck" not in (name for mw, name in middlewares): diff --git a/tests/core/manager/test_default_middlewares.py b/tests/core/manager/test_default_middlewares.py index 179fd50ecb..1ad566808d 100644 --- a/tests/core/manager/test_default_middlewares.py +++ b/tests/core/manager/test_default_middlewares.py @@ -13,7 +13,7 @@ def test_default_sync_middlewares(w3): expected_middlewares = [ (gas_price_strategy_middleware, "gas_price_strategy"), - (ens_name_to_address_middleware, "name_to_address"), + (ens_name_to_address_middleware, "ens_name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), (buffered_gas_estimate_middleware, "gas_estimate"), diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 892aff0bad..dbd1d8d79d 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -89,7 +89,7 @@ def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> == gas_price_strategy_middleware ) assert ( - async_w3.middleware_onion.get("name_to_address") + async_w3.middleware_onion.get("ens_name_to_address") == ens_name_to_address_middleware ) assert async_w3.middleware_onion.get("attrdict") == attrdict_middleware diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 126ab5ffba..4567d18fe9 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -84,7 +84,9 @@ def test_web3_with_http_provider_has_default_middlewares_and_modules() -> None: assert ( w3.middleware_onion.get("gas_price_strategy") == gas_price_strategy_middleware ) - assert w3.middleware_onion.get("name_to_address") == ens_name_to_address_middleware + assert ( + w3.middleware_onion.get("ens_name_to_address") == ens_name_to_address_middleware + ) assert w3.middleware_onion.get("attrdict") == attrdict_middleware assert w3.middleware_onion.get("validation") == validation_middleware assert w3.middleware_onion.get("gas_estimate") == buffered_gas_estimate_middleware diff --git a/web3/manager.py b/web3/manager.py index 28428d7c7f..49b99955f4 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -163,7 +163,7 @@ def get_default_middlewares() -> List[Tuple[Middleware, str]]: """ return [ (gas_price_strategy_middleware, "gas_price_strategy"), - (ens_name_to_address_middleware, "name_to_address"), + (ens_name_to_address_middleware, "ens_name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), (buffered_gas_estimate_middleware, "gas_estimate"), From 35b245ade4685b3d651055bd7af6b365f56a1450 Mon Sep 17 00:00:00 2001 From: fselmo Date: Mon, 18 Dec 2023 14:37:08 -0700 Subject: [PATCH 21/24] Make caching more robust; add tests back for cached requests - Make request caching more robust by re-introducing the lock from the middleware as well as checking for "error" and None responses and not caching those. - Re-introduce the tests for the simple_cache_middleware as caching utils tests. [Unrelated] - Keep the decorator logic on ``make_request`` when using the ``request_mocker`` fixture so we can still utilized the request mocker for testing cached requests. --- .../caching-utils/test_request_caching.py | 238 ++++++++++++++++++ web3/_utils/caching.py | 29 ++- web3/_utils/module_testing/utils.py | 54 +++- web3/providers/async_base.py | 6 + web3/providers/base.py | 6 + 5 files changed, 316 insertions(+), 17 deletions(-) create mode 100644 tests/core/caching-utils/test_request_caching.py diff --git a/tests/core/caching-utils/test_request_caching.py b/tests/core/caching-utils/test_request_caching.py new file mode 100644 index 0000000000..a68d3c04e0 --- /dev/null +++ b/tests/core/caching-utils/test_request_caching.py @@ -0,0 +1,238 @@ +import itertools +import pytest +import threading +import uuid + +import pytest_asyncio + +from web3 import ( + AsyncWeb3, + Web3, +) +from web3._utils.caching import ( + generate_cache_key, +) +from web3.providers import ( + AsyncBaseProvider, + BaseProvider, +) +from web3.types import ( + RPCEndpoint, +) +from web3.utils import ( + SimpleCache, +) + + +def simple_cache_return_value_a(): + _cache = SimpleCache() + _cache.cache( + generate_cache_key(f"{threading.get_ident()}:{('fake_endpoint', [1])}"), + {"result": "value-a"}, + ) + return _cache + + +@pytest.fixture +def w3(request_mocker): + _w3 = Web3(provider=BaseProvider()) + _w3.provider.cache_allowed_requests = True + _w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + with request_mocker( + _w3, + mock_results={ + "fake_endpoint": lambda *_: uuid.uuid4(), + "not_on_allowlist": lambda *_: uuid.uuid4(), + }, + ): + yield _w3 + + # clear request cache after each test + _w3.provider._request_cache.clear() + + +def test_request_caching_pulls_from_cache(w3): + w3.provider._request_cache = simple_cache_return_value_a() + assert w3.manager.request_blocking("fake_endpoint", [1]) == "value-a" + + +def test_request_caching_populates_cache(w3): + result = w3.manager.request_blocking("fake_endpoint", []) + assert w3.manager.request_blocking("fake_endpoint", []) == result + assert w3.manager.request_blocking("fake_endpoint", [1]) != result + assert len(w3.provider._request_cache.items()) == 2 + + +def test_request_caching_does_not_cache_none_responses(request_mocker): + w3 = Web3(BaseProvider()) + w3.provider.cache_allowed_requests = True + w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + counter = itertools.count() + + def result_cb(_method, _params): + next(counter) + return None + + with request_mocker(w3, mock_results={"fake_endpoint": result_cb}): + w3.manager.request_blocking("fake_endpoint", []) + w3.manager.request_blocking("fake_endpoint", []) + + assert next(counter) == 2 + + +def test_request_caching_does_not_cache_error_responses(request_mocker): + w3 = Web3(provider=BaseProvider()) + w3.provider.cache_allowed_requests = True + w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + with request_mocker( + w3, mock_errors={"fake_endpoint": lambda *_: {"message": f"msg-{uuid.uuid4()}"}} + ): + with pytest.raises(ValueError) as err_a: + w3.manager.request_blocking("fake_endpoint", []) + with pytest.raises(ValueError) as err_b: + w3.manager.request_blocking("fake_endpoint", []) + + assert str(err_a) != str(err_b) + assert err_a.value.args != err_b.value.args + + +def test_request_caching_does_not_cache_endpoints_not_in_whitelist(w3): + result_a = w3.manager.request_blocking("not_on_allowlist", []) + result_b = w3.manager.request_blocking("not_on_allowlist", []) + assert result_a != result_b + + +def test_caching_requests_does_not_share_state_between_providers(request_mocker): + w3_a, w3_b, w3_c = ( + Web3(provider=BaseProvider()), + Web3(provider=BaseProvider()), + Web3(provider=BaseProvider()), + ) + mock_results_a = {RPCEndpoint("eth_chainId"): 11111} + mock_results_b = {RPCEndpoint("eth_chainId"): 22222} + mock_results_c = {RPCEndpoint("eth_chainId"): 33333} + + with request_mocker(w3_a, mock_results=mock_results_a): + with request_mocker(w3_b, mock_results=mock_results_b): + with request_mocker(w3_c, mock_results=mock_results_c): + result_a = w3_a.manager.request_blocking("eth_chainId", []) + result_b = w3_b.manager.request_blocking("eth_chainId", []) + result_c = w3_c.manager.request_blocking("eth_chainId", []) + + assert result_a != result_b != result_c + assert result_a == 11111 + assert result_b == 22222 + assert result_c == 33333 + + +# -- async -- # + + +@pytest_asyncio.fixture +async def async_w3(request_mocker): + _async_w3 = AsyncWeb3(AsyncBaseProvider()) + _async_w3.provider.cache_allowed_requests = True + _async_w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + async with request_mocker( + _async_w3, + mock_results={ + "fake_endpoint": lambda *_: uuid.uuid4(), + "not_on_allowlist": lambda *_: uuid.uuid4(), + }, + ): + yield _async_w3 + + # clear request cache after each test + _async_w3.provider._request_cache.clear() + + +@pytest.mark.asyncio +async def test_async_request_caching_pulls_from_cache(async_w3): + async_w3.provider._request_cache = simple_cache_return_value_a() + _result = await async_w3.manager.coro_request("fake_endpoint", [1]) + assert _result == "value-a" + + +@pytest.mark.asyncio +async def test_async_request_caching_populates_cache(async_w3): + result = await async_w3.manager.coro_request("fake_endpoint", []) + + _empty_params = await async_w3.manager.coro_request("fake_endpoint", []) + _non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1]) + + assert _empty_params == result + assert _non_empty_params != result + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_cache_none_responses(request_mocker): + async_w3 = AsyncWeb3(AsyncBaseProvider()) + async_w3.provider.cache_allowed_requests = True + async_w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + counter = itertools.count() + + def result_cb(_method, _params): + next(counter) + return None + + async with request_mocker(async_w3, mock_results={"fake_endpoint": result_cb}): + await async_w3.manager.coro_request("fake_endpoint", []) + await async_w3.manager.coro_request("fake_endpoint", []) + + assert next(counter) == 2 + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_cache_error_responses(request_mocker): + async_w3 = AsyncWeb3(AsyncBaseProvider()) + async_w3.provider.cache_allowed_requests = True + async_w3.provider.cacheable_requests += (RPCEndpoint("fake_endpoint"),) + + async with request_mocker( + async_w3, + mock_errors={"fake_endpoint": lambda *_: {"message": f"msg-{uuid.uuid4()}"}}, + ): + with pytest.raises(ValueError) as err_a: + await async_w3.manager.coro_request("fake_endpoint", []) + with pytest.raises(ValueError) as err_b: + await async_w3.manager.coro_request("fake_endpoint", []) + + assert str(err_a) != str(err_b) + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_cache_non_whitelist_endpoints( + async_w3, +): + result_a = await async_w3.manager.coro_request("not_on_allowlist", []) + result_b = await async_w3.manager.coro_request("not_on_allowlist", []) + assert result_a != result_b + + +@pytest.mark.asyncio +async def test_async_request_caching_does_not_share_state_between_providers( + request_mocker, +): + async_w3_a, async_w3_b, async_w3_c = ( + AsyncWeb3(AsyncBaseProvider()), + AsyncWeb3(AsyncBaseProvider()), + AsyncWeb3(AsyncBaseProvider()), + ) + mock_results_a = {RPCEndpoint("eth_chainId"): 11111} + mock_results_b = {RPCEndpoint("eth_chainId"): 22222} + mock_results_c = {RPCEndpoint("eth_chainId"): 33333} + + async with request_mocker(async_w3_a, mock_results=mock_results_a): + async with request_mocker(async_w3_b, mock_results=mock_results_b): + async with request_mocker(async_w3_c, mock_results=mock_results_c): + result_a = await async_w3_a.manager.coro_request("eth_chainId", []) + result_b = await async_w3_b.manager.coro_request("eth_chainId", []) + result_c = await async_w3_c.manager.coro_request("eth_chainId", []) + + assert result_a != result_b != result_c + assert result_a == 11111 + assert result_b == 22222 + assert result_c == 33333 diff --git a/web3/_utils/caching.py b/web3/_utils/caching.py index 8d958607e7..a10e901cb4 100644 --- a/web3/_utils/caching.py +++ b/web3/_utils/caching.py @@ -1,5 +1,6 @@ import collections import hashlib +import threading from typing import ( TYPE_CHECKING, Any, @@ -85,6 +86,14 @@ def is_cacheable_request( # -- request caching decorators -- # +def _should_cache_response(response: "RPCResponse") -> bool: + return ( + "error" not in response + and "result" in response + and not is_null(response["result"]) + ) + + def handle_request_caching( func: Callable[[SYNC_PROVIDER_TYPE, "RPCEndpoint", Any], "RPCResponse"] ) -> Callable[..., "RPCResponse"]: @@ -93,17 +102,23 @@ def wrapper( ) -> "RPCResponse": if is_cacheable_request(provider, method): request_cache = provider._request_cache - cache_key = generate_cache_key((method, params)) + cache_key = generate_cache_key( + f"{threading.get_ident()}:{(method, params)}" + ) cache_result = request_cache.get_cache_entry(cache_key) if cache_result is not None: return cache_result else: response = func(provider, method, params) - request_cache.cache(cache_key, response) + if _should_cache_response(response): + with provider._request_cache_lock: + request_cache.cache(cache_key, response) return response else: return func(provider, method, params) + # save a reference to the decorator on the wrapped function + wrapper._decorator = handle_request_caching # type: ignore return wrapper @@ -120,15 +135,21 @@ async def wrapper( ) -> "RPCResponse": if is_cacheable_request(provider, method): request_cache = provider._request_cache - cache_key = generate_cache_key((method, params)) + cache_key = generate_cache_key( + f"{threading.get_ident()}:{(method, params)}" + ) cache_result = request_cache.get_cache_entry(cache_key) if cache_result is not None: return cache_result else: response = await func(provider, method, params) - request_cache.cache(cache_key, response) + if _should_cache_response(response): + async with provider._request_cache_lock: + request_cache.cache(cache_key, response) return response else: return await func(provider, method, params) + # save a reference to the decorator on the wrapped function + wrapper._decorator = async_handle_request_caching # type: ignore return wrapper diff --git a/web3/_utils/module_testing/utils.py b/web3/_utils/module_testing/utils.py index f7bd59de41..b2e30ea274 100644 --- a/web3/_utils/module_testing/utils.py +++ b/web3/_utils/module_testing/utils.py @@ -14,10 +14,6 @@ merge, ) -from web3.exceptions import ( - Web3ValidationError, -) - if TYPE_CHECKING: from web3 import ( # noqa: F401 AsyncWeb3, @@ -59,7 +55,7 @@ def __init__( self, w3: Union["AsyncWeb3", "Web3"], mock_results: Dict[Union["RPCEndpoint", str], Any] = None, - mock_errors: Dict[Union["RPCEndpoint", str], Dict[str, Any]] = None, + mock_errors: Dict[Union["RPCEndpoint", str], Any] = None, ): self.w3 = w3 self.mock_results = mock_results or {} @@ -101,20 +97,31 @@ def _mock_request_handler( mock_return = self.mock_results[method] if callable(mock_return): mock_return = mock_return(method, params) - return merge(response_dict, {"result": mock_return}) + mocked_response = merge(response_dict, {"result": mock_return}) elif method in self.mock_errors: error = self.mock_errors[method] - if not isinstance(error, dict): - raise Web3ValidationError("error must be a dict") + if callable(error): + error = error(method, params) code = error.get("code", -32000) message = error.get("message", "Mocked error") - return merge( + mocked_response = merge( response_dict, {"error": merge({"code": code, "message": message}, error)}, ) else: raise Exception("Invariant: unreachable code path") + decorator = getattr(self._make_request, "_decorator", None) + if decorator is not None: + # If the original make_request was decorated, we need to re-apply + # the decorator to the mocked make_request. This is necessary for + # the request caching decorator to work properly. + return decorator(lambda *_: mocked_response)( + self.w3.provider, method, params + ) + else: + return mocked_response + # -- async -- # async def __aenter__(self) -> "Self": setattr(self.w3.provider, "make_request", self._async_mock_request_handler) @@ -151,16 +158,37 @@ async def _async_mock_request_handler( elif iscoroutinefunction(mock_return): # this is the "correct" way to mock the async make_request mock_return = await mock_return(method, params) - return merge(response_dict, {"result": mock_return}) + + mocked_result = merge(response_dict, {"result": mock_return}) + elif method in self.mock_errors: error = self.mock_errors[method] - if not isinstance(error, dict): - raise Web3ValidationError("error must be a dict") + if callable(error): + error = error(method, params) + elif iscoroutinefunction(error): + error = await error(method, params) + code = error.get("code", -32000) message = error.get("message", "Mocked error") - return merge( + mocked_result = merge( response_dict, {"error": merge({"code": code, "message": message}, error)}, ) + else: raise Exception("Invariant: unreachable code path") + + decorator = getattr(self._make_request, "_decorator", None) + if decorator is not None: + # If the original make_request was decorated, we need to re-apply + # the decorator to the mocked make_request. This is necessary for + # the request caching decorator to work properly. + + async def _coro( + _provider: Any, _method: "RPCEndpoint", _params: Any + ) -> "RPCResponse": + return mocked_result + + return await decorator(_coro)(self.w3.provider, method, params) + else: + return mocked_result diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 19a3e6a37a..9cc59be7ab 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -1,3 +1,4 @@ +import asyncio import itertools from typing import ( TYPE_CHECKING, @@ -15,6 +16,9 @@ to_text, ) +from web3._utils.caching import ( + async_handle_request_caching, +) from web3._utils.encoding import ( FriendlyJsonSerde, Web3JsonEncoder, @@ -75,6 +79,7 @@ class AsyncBaseProvider: cache_allowed_requests: bool = False cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS _request_cache: SimpleCache + _request_cache_lock: asyncio.Lock = asyncio.Lock() def __init__(self) -> None: self._request_cache = SimpleCache(1000) @@ -97,6 +102,7 @@ async def request_func( ) return self._request_func_cache[-1] + @async_handle_request_caching async def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") diff --git a/web3/providers/base.py b/web3/providers/base.py index a91e9d8708..81866f24fe 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -1,4 +1,5 @@ import itertools +import threading from typing import ( TYPE_CHECKING, Any, @@ -13,6 +14,9 @@ to_text, ) +from web3._utils.caching import ( + handle_request_caching, +) from web3._utils.encoding import ( FriendlyJsonSerde, Web3JsonEncoder, @@ -72,6 +76,7 @@ class BaseProvider: cache_allowed_requests: bool = False cacheable_requests: Set[RPCEndpoint] = CACHEABLE_REQUESTS _request_cache: SimpleCache + _request_cache_lock: threading.Lock = threading.Lock() def __init__(self) -> None: self._request_cache = SimpleCache(1000) @@ -102,6 +107,7 @@ def request_func( return self._request_func_cache[-1] + @handle_request_caching def make_request(self, method: RPCEndpoint, params: Any) -> RPCResponse: raise NotImplementedError("Providers must implement this method") From c3894c6a693fce4eba2f85e05fa50153f59a69ae Mon Sep 17 00:00:00 2001 From: fselmo Date: Tue, 16 Jan 2024 12:29:21 -0300 Subject: [PATCH 22/24] Changes from comments on PR #3169 --- docs/middleware.rst | 12 ------------ tests/core/caching-utils/test_request_caching.py | 6 ++---- tests/core/providers/test_http_request_retry.py | 12 ++++++++++-- web3/main.py | 13 +++++-------- web3/middleware/cache.py | 0 web3/middleware/exception_retry_request.py | 0 6 files changed, 17 insertions(+), 26 deletions(-) delete mode 100644 web3/middleware/cache.py delete mode 100644 web3/middleware/exception_retry_request.py diff --git a/docs/middleware.rst b/docs/middleware.rst index 1ff48d1b59..aa767bca0f 100644 --- a/docs/middleware.rst +++ b/docs/middleware.rst @@ -104,18 +104,6 @@ Buffered Gas Estimate ``min(w3.eth.estimate_gas + gas_buffer, gas_limit)`` where the gas_buffer default is 100,000 -HTTPRequestRetry -~~~~~~~~~~~~~~~~~~ - -.. py:method:: web3.middleware.http_retry_request_middleware - web3.middleware.async_http_retry_request_middleware - - This middleware is a default specifically for HTTPProvider that retries failed - requests that return the following errors: ``ConnectionError``, ``HTTPError``, ``Timeout``, - ``TooManyRedirects``. Additionally there is a whitelist that only allows certain - methods to be retried in order to not resend transactions, excluded methods are: - ``eth_sendTransaction``, ``personal_signAndSendTransaction``, ``personal_sendTransaction``. - Validation ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/tests/core/caching-utils/test_request_caching.py b/tests/core/caching-utils/test_request_caching.py index a68d3c04e0..6307e02dfa 100644 --- a/tests/core/caching-utils/test_request_caching.py +++ b/tests/core/caching-utils/test_request_caching.py @@ -98,7 +98,7 @@ def test_request_caching_does_not_cache_error_responses(request_mocker): assert err_a.value.args != err_b.value.args -def test_request_caching_does_not_cache_endpoints_not_in_whitelist(w3): +def test_request_caching_does_not_cache_endpoints_not_in_allowlist(w3): result_a = w3.manager.request_blocking("not_on_allowlist", []) result_b = w3.manager.request_blocking("not_on_allowlist", []) assert result_a != result_b @@ -121,7 +121,6 @@ def test_caching_requests_does_not_share_state_between_providers(request_mocker) result_b = w3_b.manager.request_blocking("eth_chainId", []) result_c = w3_c.manager.request_blocking("eth_chainId", []) - assert result_a != result_b != result_c assert result_a == 11111 assert result_b == 22222 assert result_c == 33333 @@ -204,7 +203,7 @@ async def test_async_request_caching_does_not_cache_error_responses(request_mock @pytest.mark.asyncio -async def test_async_request_caching_does_not_cache_non_whitelist_endpoints( +async def test_async_request_caching_does_not_cache_non_allowlist_endpoints( async_w3, ): result_a = await async_w3.manager.coro_request("not_on_allowlist", []) @@ -232,7 +231,6 @@ async def test_async_request_caching_does_not_share_state_between_providers( result_b = await async_w3_b.manager.coro_request("eth_chainId", []) result_c = await async_w3_c.manager.coro_request("eth_chainId", []) - assert result_a != result_b != result_c assert result_a == 11111 assert result_b == 22222 assert result_c == 33333 diff --git a/tests/core/providers/test_http_request_retry.py b/tests/core/providers/test_http_request_retry.py index 22d0c2debd..938eb963dc 100644 --- a/tests/core/providers/test_http_request_retry.py +++ b/tests/core/providers/test_http_request_retry.py @@ -15,6 +15,7 @@ AsyncHTTPProvider, AsyncWeb3, Web3, + WebsocketProviderV2, ) from web3.providers import ( HTTPProvider, @@ -40,7 +41,10 @@ def w3(): def test_default_request_retry_configuration_for_http_provider(): w3 = Web3(HTTPProvider()) - assert w3.provider.exception_retry_configuration == ExceptionRetryConfiguration() + assert ( + getattr(w3.provider, "exception_retry_configuration") + == ExceptionRetryConfiguration() + ) def test_check_if_retry_on_failure_false(): @@ -82,6 +86,9 @@ def test_exception_retry_config_is_strictly_on_http_provider(): w3 = Web3(IPCProvider()) assert not hasattr(w3.provider, "exception_retry_configuration") + w3 = AsyncWeb3.persistent_websocket(WebsocketProviderV2("ws://localhost:8546")) + assert not hasattr(w3.provider, "exception_retry_configuration") + @patch("web3.providers.rpc.rpc.make_post_request", side_effect=ConnectionError) def test_exception_retry_middleware_with_allow_list_kwarg(make_post_request_mock): @@ -116,7 +123,8 @@ def async_w3(): async def test_async_default_request_retry_configuration_for_http_provider(): async_w3 = AsyncWeb3(AsyncHTTPProvider()) assert ( - async_w3.provider.exception_retry_configuration == ExceptionRetryConfiguration() + getattr(async_w3.provider, "exception_retry_configuration") + == ExceptionRetryConfiguration() ) diff --git a/web3/main.py b/web3/main.py index 51a238d82f..604ea4fcdf 100644 --- a/web3/main.py +++ b/web3/main.py @@ -193,12 +193,17 @@ class BaseWeb3: # Managers RequestManager = DefaultRequestManager + manager: DefaultRequestManager # mypy types eth: Union[Eth, AsyncEth] net: Union[Net, AsyncNet] geth: Union[Geth, AsyncGeth] + @property + def middleware_onion(self) -> MiddlewareOnion: + return cast(MiddlewareOnion, self.manager.middleware_onion) + # Encoding and Decoding @staticmethod @wraps(to_bytes) @@ -400,10 +405,6 @@ def __init__( def is_connected(self, show_traceback: bool = False) -> bool: return self.provider.is_connected(show_traceback) - @property - def middleware_onion(self) -> MiddlewareOnion: - return cast(MiddlewareOnion, self.manager.middleware_onion) - @property def provider(self) -> BaseProvider: return cast(BaseProvider, self.manager.provider) @@ -468,10 +469,6 @@ def __init__( async def is_connected(self, show_traceback: bool = False) -> bool: return await self.provider.is_connected(show_traceback) - @property - def middleware_onion(self) -> MiddlewareOnion: - return cast(MiddlewareOnion, self.manager.middleware_onion) - @property def provider(self) -> AsyncBaseProvider: return cast(AsyncBaseProvider, self.manager.provider) diff --git a/web3/middleware/cache.py b/web3/middleware/cache.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/web3/middleware/exception_retry_request.py b/web3/middleware/exception_retry_request.py deleted file mode 100644 index e69de29bb2..0000000000 From a548a4e420b304d548270657b2868cb0be8ee6d6 Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 18 Jan 2024 14:53:18 -0700 Subject: [PATCH 23/24] Use as_tuple_of_middlewares() to get around typing issues - Use an internal method for the ``NamedMiddlewareOnion`` to generate the tuple of middlewares, instead of relying on ``tuple()``. --- web3/datastructures.py | 31 ++++++++++++++++++++++++------- web3/providers/async_base.py | 9 ++++----- web3/providers/base.py | 11 +++++------ web3/providers/eth_tester/main.py | 22 ++++++++++++---------- 4 files changed, 45 insertions(+), 28 deletions(-) diff --git a/web3/datastructures.py b/web3/datastructures.py index 2aa3632d45..0b33f95ae3 100644 --- a/web3/datastructures.py +++ b/web3/datastructures.py @@ -268,13 +268,6 @@ def _replace_with_new_name(self, old: TKey, new: TKey) -> None: self._queue.move_to_end(key) del self._queue[old] - def __iter__(self) -> Iterator[TKey]: - elements = self._queue.values() - if not isinstance(elements, Sequence): - # type ignored b/c elements is set as _OrderedDictValuesView[Any] on 210 - elements = list(elements) # type: ignore - return iter(reversed(elements)) - def __add__(self, other: Any) -> "NamedElementOnion[TKey, TValue]": if not isinstance(other, NamedElementOnion): # you can only combine with another ``NamedElementOnion`` @@ -297,3 +290,27 @@ def __reversed__(self) -> Iterator[TValue]: if not isinstance(elements, Sequence): elements = list(elements) return iter(elements) + + # --- iter and tupleize methods --- # + + def _reversed_middlewares(self) -> Iterator[TValue]: + elements = self._queue.values() + if not isinstance(elements, Sequence): + # type ignored b/c elements is set as _OrderedDictValuesView[Any] on 210 + elements = list(elements) # type: ignore + return reversed(elements) + + def as_tuple_of_middlewares(self) -> Tuple[TValue, ...]: + """ + This helps with type hinting since we return `Iterator[TKey]` type, though it's + actually a `Iterator[TValue]` type, for the `__iter__()` method. This is in + order to satisfy the `Mapping` interface. + """ + return tuple(self._reversed_middlewares()) + + def __iter__(self) -> Iterator[TKey]: + # ``__iter__()`` for a ``Mapping`` returns ``Iterator[TKey]`` but this + # implementation returns ``Iterator[TValue]`` on reversed values (not keys). + # This leads to typing issues, so it's better to use + # ``as_tuple_of_middlewares()`` to achieve the same result. + return iter(self._reversed_middlewares()) # type: ignore diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index 9cc59be7ab..471e44e66e 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -85,17 +85,16 @@ def __init__(self) -> None: self._request_cache = SimpleCache(1000) async def request_func( - self, async_w3: "AsyncWeb3", middlewares: MiddlewareOnion + self, async_w3: "AsyncWeb3", middleware_onion: MiddlewareOnion ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: - # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - middlewares: Tuple[Middleware, ...] = tuple(middlewares) # type: ignore + middlewares: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middlewares() cache_key = self._request_func_cache[0] - if cache_key != middlewares: # type: ignore + if cache_key != middlewares: self._request_func_cache = ( middlewares, await async_combine_middlewares( - middlewares=middlewares, # type: ignore + middlewares=middlewares, async_w3=async_w3, provider_request_fn=self.make_request, ), diff --git a/web3/providers/base.py b/web3/providers/base.py index 81866f24fe..ba5d4c5962 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -82,24 +82,23 @@ def __init__(self) -> None: self._request_cache = SimpleCache(1000) def request_func( - self, w3: "Web3", middlewares: MiddlewareOnion + self, w3: "Web3", middleware_onion: MiddlewareOnion ) -> Callable[..., RPCResponse]: """ @param w3 is the web3 instance - @param middlewares is an iterable of middlewares, + @param middleware_onion is an iterable of middlewares, ordered by first to execute @returns a function that calls all the middleware and eventually self.make_request() """ - # type ignored b/c tuple(MiddlewareOnion) converts to tuple of middlewares - middlewares: Tuple[Middleware] = tuple(middlewares) # type: ignore + middlewares: Tuple[Middleware, ...] = middleware_onion.as_tuple_of_middlewares() cache_key = self._request_func_cache[0] - if cache_key != middlewares: # type: ignore + if cache_key != middlewares: self._request_func_cache = ( middlewares, combine_middlewares( - middlewares=middlewares, # type: ignore + middlewares=middlewares, w3=w3, provider_request_fn=self.make_request, ), diff --git a/web3/providers/eth_tester/main.py b/web3/providers/eth_tester/main.py index 45b95718c8..70cb3f7104 100644 --- a/web3/providers/eth_tester/main.py +++ b/web3/providers/eth_tester/main.py @@ -79,19 +79,20 @@ def __init__(self) -> None: self.api_endpoints = API_ENDPOINTS async def request_func( - self, async_w3: "AsyncWeb3", middlewares: "MiddlewareOnion" + self, async_w3: "AsyncWeb3", middleware_onion: "MiddlewareOnion" ) -> Callable[..., Coroutine[Any, Any, RPCResponse]]: # override the request_func to add the ethereum_tester_middleware - # type ignored bc tuple(MiddlewareOnion) converts to tuple of middlewares - middlewares = tuple(middlewares) + tuple(self._middlewares) # type: ignore + middlewares = middleware_onion.as_tuple_of_middlewares() + tuple( + self._middlewares + ) cache_key = self._request_func_cache[0] - if cache_key != middlewares: # type: ignore + if cache_key != middlewares: self._request_func_cache = ( middlewares, await async_combine_middlewares( - middlewares=middlewares, # type: ignore + middlewares=middlewares, async_w3=async_w3, provider_request_fn=self.make_request, ), @@ -154,19 +155,20 @@ def __init__( self.api_endpoints = api_endpoints def request_func( - self, w3: "Web3", middlewares: "MiddlewareOnion" + self, w3: "Web3", middleware_onion: "MiddlewareOnion" ) -> Callable[..., RPCResponse]: # override the request_func to add the ethereum_tester_middleware - # type ignored bc tuple(MiddlewareOnion) converts to tuple of middlewares - middlewares = tuple(middlewares) + tuple(self._middlewares) # type: ignore + middlewares = middleware_onion.as_tuple_of_middlewares() + tuple( + self._middlewares + ) cache_key = self._request_func_cache[0] - if cache_key != middlewares: # type: ignore + if cache_key != middlewares: self._request_func_cache = ( middlewares, combine_middlewares( - middlewares=middlewares, # type: ignore + middlewares=middlewares, w3=w3, provider_request_fn=self.make_request, ), From 91e872d63b4bb747cf2dcb014f515d6d062faebc Mon Sep 17 00:00:00 2001 From: fselmo Date: Thu, 18 Jan 2024 16:44:46 -0700 Subject: [PATCH 24/24] newsfragment for 3169 --- newsfragments/3169.breaking.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 newsfragments/3169.breaking.rst diff --git a/newsfragments/3169.breaking.rst b/newsfragments/3169.breaking.rst new file mode 100644 index 0000000000..60fad67045 --- /dev/null +++ b/newsfragments/3169.breaking.rst @@ -0,0 +1 @@ +Refactor the middleware setup so that request processors and response processors are separated. This will allow for more flexibility in the future and aid in the implementation of features such as batched requests. This PR also closes out a few outstanding issues and will be the start of the breaking changes for `web3.py` ``v7``. Review PR for a full list of changes.