diff --git a/newsfragments/3120.feature.rst b/newsfragments/3120.feature.rst new file mode 100644 index 0000000000..2b60197aee --- /dev/null +++ b/newsfragments/3120.feature.rst @@ -0,0 +1 @@ +Add ``allow_list`` kwarg for ``exception_retry_middleware`` to allow for a custom list of RPC endpoints. Add a sleep between retries and a customizable ``backoff_factor`` to control the sleep time between retry attempts. diff --git a/tests/core/middleware/test_http_request_retry.py b/tests/core/middleware/test_http_request_retry.py index 27a96f7992..02060b5d85 100644 --- a/tests/core/middleware/test_http_request_retry.py +++ b/tests/core/middleware/test_http_request_retry.py @@ -99,6 +99,27 @@ def test_check_with_all_middlewares(make_post_request_mock): assert make_post_request_mock.call_count == 5 +@patch("web3.providers.rpc.make_post_request", side_effect=ConnectionError) +def test_exception_retry_middleware_with_allow_list_kwarg( + make_post_request_mock, exception_retry_request_setup +): + w3 = Mock() + provider = HTTPProvider() + errors = (ConnectionError, HTTPError, Timeout, TooManyRedirects) + setup = exception_retry_middleware( + provider.make_request, w3, errors, 5, allow_list=["test_userProvidedMethod"] + ) + setup.w3 = w3 + with pytest.raises(ConnectionError): + setup("test_userProvidedMethod", []) + assert make_post_request_mock.call_count == 5 + + make_post_request_mock.reset_mock() + with pytest.raises(ConnectionError): + setup("eth_getBalance", []) + assert make_post_request_mock.call_count == 1 + + # -- async -- # @@ -132,27 +153,56 @@ async def async_exception_retry_request_setup(): aiohttp.ClientOSError, ), ) -async def test_check_retry_middleware(error, async_exception_retry_request_setup): +async def test_async_check_retry_middleware(error, async_exception_retry_request_setup): with patch( "web3.providers.async_rpc.async_make_post_request" - ) as make_post_request_mock: - make_post_request_mock.side_effect = error + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = error with pytest.raises(error): await async_exception_retry_request_setup("eth_getBalance", []) - assert make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT + assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT @pytest.mark.asyncio -async def test_check_without_retry_middleware(): +async def test_async_check_without_retry_middleware(): with patch( "web3.providers.async_rpc.async_make_post_request" - ) as make_post_request_mock: - make_post_request_mock.side_effect = TimeoutError + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = TimeoutError provider = AsyncHTTPProvider() w3 = AsyncWeb3(provider) w3.provider._middlewares = () with pytest.raises(TimeoutError): await w3.eth.block_number - assert make_post_request_mock.call_count == 1 + assert async_make_post_request_mock.call_count == 1 + + +@pytest.mark.asyncio +async def test_async_exception_retry_middleware_with_allow_list_kwarg(): + w3 = Mock() + provider = AsyncHTTPProvider() + setup = await async_exception_retry_middleware( + provider.make_request, + w3, + (TimeoutError, aiohttp.ClientError), + retries=ASYNC_TEST_RETRY_COUNT, + backoff_factor=0.1, + allow_list=["test_userProvidedMethod"], + ) + setup.w3 = w3 + + with patch( + "web3.providers.async_rpc.async_make_post_request" + ) as async_make_post_request_mock: + async_make_post_request_mock.side_effect = TimeoutError + + with pytest.raises(TimeoutError): + await setup("test_userProvidedMethod", []) + assert async_make_post_request_mock.call_count == ASYNC_TEST_RETRY_COUNT + + async_make_post_request_mock.reset_mock() + with pytest.raises(TimeoutError): + await setup("eth_getBalance", []) + assert async_make_post_request_mock.call_count == 1 diff --git a/web3/middleware/exception_retry_request.py b/web3/middleware/exception_retry_request.py index 616687e8b9..81f4ef0743 100644 --- a/web3/middleware/exception_retry_request.py +++ b/web3/middleware/exception_retry_request.py @@ -1,9 +1,11 @@ import asyncio +import time from typing import ( TYPE_CHECKING, Any, Callable, Collection, + List, Optional, Type, ) @@ -28,7 +30,7 @@ Web3, ) -whitelist = [ +DEFAULT_ALLOWLIST = [ "admin", "miner", "net", @@ -87,11 +89,16 @@ ] -def check_if_retry_on_failure(method: RPCEndpoint) -> bool: +def check_if_retry_on_failure( + method: str, allow_list: Optional[List[str]] = None +) -> bool: + if allow_list is None: + allow_list = DEFAULT_ALLOWLIST + root = method.split("_")[0] - if root in whitelist: + if root in allow_list: return True - elif method in whitelist: + elif method in allow_list: return True else: return False @@ -102,6 +109,8 @@ def exception_retry_middleware( _w3: "Web3", errors: Collection[Type[BaseException]], retries: int = 5, + backoff_factor: float = 0.3, + allow_list: Optional[List[str]] = None, ) -> Callable[[RPCEndpoint, Any], RPCResponse]: """ Creates middleware that retries failed HTTP requests. Is a default @@ -109,12 +118,13 @@ def exception_retry_middleware( """ def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if check_if_retry_on_failure(method): + if check_if_retry_on_failure(method, allow_list): for i in range(retries): try: return make_request(method, params) except tuple(errors): if i < retries - 1: + time.sleep(backoff_factor) continue else: raise @@ -133,12 +143,16 @@ def http_retry_request_middleware( ) +# -- async -- # + + async def async_exception_retry_middleware( make_request: Callable[[RPCEndpoint, Any], Any], _async_w3: "AsyncWeb3", errors: Collection[Type[BaseException]], retries: int = 5, backoff_factor: float = 0.3, + allow_list: Optional[List[str]] = None, ) -> AsyncMiddlewareCoroutine: """ Creates middleware that retries failed HTTP requests. @@ -146,7 +160,7 @@ async def async_exception_retry_middleware( """ async def middleware(method: RPCEndpoint, params: Any) -> Optional[RPCResponse]: - if check_if_retry_on_failure(method): + if check_if_retry_on_failure(method, allow_list): for i in range(retries): try: return await make_request(method, params)