diff --git a/docs/providers.rst b/docs/providers.rst index 1cae277aa5..8811620774 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -282,18 +282,10 @@ AsyncHTTPProvider >>> from web3.net import AsyncNet >>> from web3.geth import Geth, AsyncGethTxPool - >>> w3 = Web3( - ... AsyncHTTPProvider(endpoint_uri), - ... modules={'eth': (AsyncEth,), - ... 'net': (AsyncNet,), - ... 'geth': (Geth, - ... {'txpool': (AsyncGethTxPool,), - ... 'personal': (AsyncGethPersonal,), - ... 'admin' : (AsyncGethAdmin,)}) - ... }, - ... middlewares=[] # See supported middleware section below for middleware options - ... ) - >>> custom_session = ClientSession() # If you want to pass in your own session + >>> w3 = Web3(AsyncHTTPProvider(endpoint_uri)) + + >>> # If you want to pass in your own session: + >>> custom_session = ClientSession() >>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly Under the hood, the ``AsyncHTTPProvider`` uses the python diff --git a/newsfragments/2736.feature.rst b/newsfragments/2736.feature.rst new file mode 100644 index 0000000000..7d5b742c9b --- /dev/null +++ b/newsfragments/2736.feature.rst @@ -0,0 +1 @@ +Load the ``AsyncHTTPProvider`` with default async middleware and default async modules, just as the ``HTTPProvider``. diff --git a/tests/core/contracts/test_contract_call_interface.py b/tests/core/contracts/test_contract_call_interface.py index 6ae2720714..7d1338e944 100644 --- a/tests/core/contracts/test_contract_call_interface.py +++ b/tests/core/contracts/test_contract_call_interface.py @@ -29,6 +29,7 @@ FallbackNotFound, InvalidAddress, MismatchedABI, + NameNotFound, NoABIFound, NoABIFunctionsFound, ValidationError, @@ -490,9 +491,9 @@ def test_call_address_reflector_name_array(address_reflector_contract, call): assert addresses == result -def test_call_reject_invalid_ens_name(address_reflector_contract, call): +def test_call_rejects_invalid_ens_name(address_reflector_contract, call): with contract_ens_addresses(address_reflector_contract, []): - with pytest.raises(ValueError): + with pytest.raises(NameNotFound): call( contract=address_reflector_contract, contract_function="reflect", @@ -1363,11 +1364,11 @@ async def test_async_call_address_reflector_name_array( @pytest.mark.xfail @pytest.mark.asyncio -async def test_async_call_reject_invalid_ens_name( +async def test_async_call_rejects_invalid_ens_name( async_address_reflector_contract, async_call ): with contract_ens_addresses(async_address_reflector_contract, []): - with pytest.raises(ValueError): + with pytest.raises(NameNotFound): await async_call( contract=async_address_reflector_contract, contract_function="reflect", diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index c5c3b6cc95..c852db91e0 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -4,18 +4,80 @@ ClientSession, ) +from web3 import Web3 from web3._utils import ( request, ) +from web3.eth import ( + AsyncEth, +) +from web3.geth import ( + AsyncGethAdmin, + AsyncGethPersonal, + AsyncGethTxPool, + Geth, +) +from web3.middleware import ( + async_buffered_gas_estimate_middleware, + async_gas_price_strategy_middleware, + async_validation_middleware, +) +from web3.net import ( + AsyncNet, +) from web3.providers.async_rpc import ( AsyncHTTPProvider, ) +URI = "http://mynode.local:8545" + + +def test_no_args(): + provider = AsyncHTTPProvider() + w3 = Web3(provider) + assert w3.manager.provider == provider + assert w3.manager.provider.is_async + + +def test_init_kwargs(): + provider = AsyncHTTPProvider(endpoint_uri=URI, request_kwargs={"timeout": 60}) + w3 = Web3(provider) + assert w3.manager.provider == provider + + +def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> None: + async_w3 = Web3(AsyncHTTPProvider(endpoint_uri=URI)) + + # assert default modules + + assert isinstance(async_w3.eth, AsyncEth) + assert isinstance(async_w3.net, AsyncNet) + assert isinstance(async_w3.geth, Geth) + assert isinstance(async_w3.geth.admin, AsyncGethAdmin) + assert isinstance(async_w3.geth.personal, AsyncGethPersonal) + assert isinstance(async_w3.geth.txpool, AsyncGethTxPool) + + # assert default middleware + + # the following length check should fail and will need to be added to once more + # async middlewares are added to the defaults + assert len(async_w3.middleware_onion.middlewares) == 3 + + assert ( + async_w3.middleware_onion.get("gas_price_strategy") + == async_gas_price_strategy_middleware + ) + assert async_w3.middleware_onion.get("validation") == async_validation_middleware + assert ( + async_w3.middleware_onion.get("gas_estimate") + == async_buffered_gas_estimate_middleware + ) + @pytest.mark.asyncio async def test_user_provided_session() -> None: session = ClientSession() - provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545") + provider = AsyncHTTPProvider(endpoint_uri=URI) cached_session = await provider.cache_async_session(session) assert len(request._async_session_cache) == 1 assert cached_session == session diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 0e6e15b502..bef9e79469 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -20,6 +20,7 @@ def test_no_args(): provider = HTTPProvider() w3 = Web3(provider) assert w3.manager.provider == provider + assert not w3.manager.provider.is_async def test_init_kwargs(): diff --git a/tests/integration/go_ethereum/test_goethereum_http.py b/tests/integration/go_ethereum/test_goethereum_http.py index deeeeb775c..98460e3af9 100644 --- a/tests/integration/go_ethereum/test_goethereum_http.py +++ b/tests/integration/go_ethereum/test_goethereum_http.py @@ -12,23 +12,6 @@ from web3._utils.module_testing.go_ethereum_personal_module import ( GoEthereumAsyncPersonalModuleTest, ) -from web3.eth import ( - AsyncEth, -) -from web3.geth import ( - AsyncGethAdmin, - AsyncGethPersonal, - AsyncGethTxPool, - Geth, -) -from web3.middleware import ( - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_validation_middleware, -) -from web3.net import ( - AsyncNet, -) from web3.providers.async_rpc import ( AsyncHTTPProvider, ) @@ -137,26 +120,7 @@ class TestGoEthereumTxPoolModuleTest(GoEthereumTxPoolModuleTest): @pytest_asyncio.fixture(scope="module") async def async_w3(geth_process, endpoint_uri): await wait_for_aiohttp(endpoint_uri) - _w3 = Web3( - AsyncHTTPProvider(endpoint_uri), - middlewares=[ - async_buffered_gas_estimate_middleware, - async_gas_price_strategy_middleware, - async_validation_middleware, - ], - modules={ - "eth": AsyncEth, - "async_net": AsyncNet, - "geth": ( - Geth, - { - "txpool": (AsyncGethTxPool,), - "personal": (AsyncGethPersonal,), - "admin": (AsyncGethAdmin,), - }, - ), - }, - ) + _w3 = Web3(AsyncHTTPProvider(endpoint_uri)) return _w3 @@ -165,22 +129,22 @@ class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest): @pytest.mark.xfail( reason="running geth with the --nodiscover flag doesn't allow peer addition" ) - async def test_admin_peers(self, w3: "Web3") -> None: - await super().test_admin_peers(w3) + async def test_admin_peers(self, async_w3: "Web3") -> None: + await super().test_admin_peers(async_w3) @pytest.mark.asyncio - async def test_admin_start_stop_http(self, w3: "Web3") -> None: + async def test_admin_start_stop_http(self, async_w3: "Web3") -> None: # This test causes all tests after it to fail on CI if it's allowed to run pytest.xfail( reason="Only one HTTP endpoint is allowed to be active at any time" ) - await super().test_admin_start_stop_http(w3) + await super().test_admin_start_stop_http(async_w3) @pytest.mark.asyncio - async def test_admin_start_stop_ws(self, w3: "Web3") -> None: + async def test_admin_start_stop_ws(self, async_w3: "Web3") -> None: # This test causes all tests after it to fail on CI if it's allowed to run pytest.xfail(reason="Only one WS endpoint is allowed to be active at any time") - await super().test_admin_start_stop_ws(w3) + await super().test_admin_start_stop_ws(async_w3) class TestGoEthereumAsyncNetModuleTest(GoEthereumAsyncNetModuleTest): diff --git a/web3/_utils/contracts.py b/web3/_utils/contracts.py index 348917d59e..26a03cc161 100644 --- a/web3/_utils/contracts.py +++ b/web3/_utils/contracts.py @@ -208,11 +208,13 @@ def encode_abi( ) normalizers = [ - abi_ens_resolver(w3), abi_address_to_hex, abi_bytes_to_bytes, abi_string_to_text, ] + if not w3.eth.is_async: + normalizers.append(abi_ens_resolver(w3)) + normalized_arguments = map_abi_data( normalizers, argument_types, diff --git a/web3/_utils/module_testing/net_module.py b/web3/_utils/module_testing/net_module.py index 2ece0efb7f..80c44a1d6d 100644 --- a/web3/_utils/module_testing/net_module.py +++ b/web3/_utils/module_testing/net_module.py @@ -34,19 +34,19 @@ def test_net_peer_count(self, w3: "Web3") -> None: class AsyncNetModuleTest: @pytest.mark.asyncio async def test_net_version(self, async_w3: "Web3") -> None: - version = await async_w3.async_net.version + version = await async_w3.net.version # type: ignore assert is_string(version) assert version.isdigit() @pytest.mark.asyncio async def test_net_listening(self, async_w3: "Web3") -> None: - listening = await async_w3.async_net.listening + listening = await async_w3.net.listening # type: ignore assert is_boolean(listening) @pytest.mark.asyncio async def test_net_peer_count(self, async_w3: "Web3") -> None: - peer_count = await async_w3.async_net.peer_count + peer_count = await async_w3.net.peer_count # type: ignore assert is_integer(peer_count) diff --git a/web3/_utils/normalizers.py b/web3/_utils/normalizers.py index f287cf9995..f7c1e38a8f 100644 --- a/web3/_utils/normalizers.py +++ b/web3/_utils/normalizers.py @@ -205,7 +205,11 @@ def abi_address_to_hex( @curry -def abi_ens_resolver(w3: "Web3", type_str: TypeStr, val: Any) -> Tuple[TypeStr, Any]: +def abi_ens_resolver( + w3: "Web3", + type_str: TypeStr, + val: Any, +) -> Tuple[TypeStr, Any]: if type_str == "address" and is_ens_name(val): if w3 is None: raise InvalidAddress( @@ -214,11 +218,12 @@ def abi_ens_resolver(w3: "Web3", type_str: TypeStr, val: Any) -> Tuple[TypeStr, ) _ens = cast(ENS, w3.ens) + net_version = int(w3.net.version) if hasattr(w3, "net") else None if _ens is None: raise InvalidAddress( f"Could not look up name {val!r} because ENS is" " set to None" ) - elif int(w3.net.version) != 1 and not isinstance(_ens, StaticENS): + elif net_version != 1 and not isinstance(_ens, StaticENS): raise InvalidAddress( f"Could not look up name {val!r} because web3 is" " not connected to mainnet" diff --git a/web3/main.py b/web3/main.py index d6f988952b..cea74e2985 100644 --- a/web3/main.py +++ b/web3/main.py @@ -74,9 +74,13 @@ abi_ens_resolver, ) from web3.eth import ( + AsyncEth, Eth, ) from web3.geth import ( + AsyncGethAdmin, + AsyncGethPersonal, + AsyncGethTxPool, Geth, GethAdmin, GethMiner, @@ -126,6 +130,21 @@ from web3._utils.empty import Empty # noqa: F401 +def get_async_default_modules() -> Dict[str, Union[Type[Module], Sequence[Any]]]: + return { + "eth": AsyncEth, + "net": AsyncNet, + "geth": ( + Geth, + { + "admin": AsyncGethAdmin, + "personal": AsyncGethPersonal, + "txpool": AsyncGethTxPool, + }, + ), + } + + def get_default_modules() -> Dict[str, Union[Type[Module], Sequence[Any]]]: return { "eth": Eth, @@ -219,7 +238,6 @@ def to_checksum_address(value: Union[AnyAddress, str, bytes]) -> ChecksumAddress eth: Eth geth: Geth net: Net - async_net: AsyncNet def __init__( self, @@ -237,7 +255,11 @@ def __init__( self.codec = ABICodec(build_default_registry()) if modules is None: - modules = get_default_modules() + modules = ( + get_async_default_modules() + if provider and provider.is_async + else get_default_modules() + ) self.attach_modules(modules) diff --git a/web3/manager.py b/web3/manager.py index a18b2f5730..ffab72d7d8 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -37,6 +37,9 @@ ) from web3.middleware import ( abi_middleware, + async_buffered_gas_estimate_middleware, + async_gas_price_strategy_middleware, + async_validation_middleware, attrdict_middleware, buffered_gas_estimate_middleware, gas_price_strategy_middleware, @@ -49,6 +52,7 @@ AutoProvider, ) from web3.types import ( # noqa: F401 + AsyncMiddleware, Middleware, MiddlewareOnion, RPCEndpoint, @@ -101,16 +105,20 @@ def __init__( self.w3 = w3 self.pending_requests: Dict[UUID, ThreadWithReturn[RPCResponse]] = {} - if middlewares is None: - middlewares = self.default_middlewares(w3) - - self.middleware_onion: MiddlewareOnion = NamedElementOnion(middlewares) - if provider is None: self.provider = AutoProvider() else: self.provider = provider + if middlewares is None: + middlewares = ( + self.async_default_middlewares(w3) + if self.provider.is_async + else self.default_middlewares(w3) + ) + + self.middleware_onion: MiddlewareOnion = NamedElementOnion(middlewares) + w3: "Web3" = None _provider = None @@ -139,6 +147,17 @@ def default_middlewares(w3: "Web3") -> List[Tuple[Middleware, str]]: (buffered_gas_estimate_middleware, "gas_estimate"), ] + @staticmethod + def async_default_middlewares(w3: "Web3") -> List[Tuple[Middleware, str]]: + """ + List the default async middlewares for the request manager. + """ + return [ + (async_gas_price_strategy_middleware, "gas_price_strategy"), + (async_validation_middleware, "validation"), + (async_buffered_gas_estimate_middleware, "gas_estimate"), + ] + # # Provider requests and response # diff --git a/web3/providers/async_base.py b/web3/providers/async_base.py index bdf4a9d11a..b411a883da 100644 --- a/web3/providers/async_base.py +++ b/web3/providers/async_base.py @@ -39,6 +39,7 @@ class AsyncBaseProvider: None, ) + is_async = True global_ccip_read_enabled: bool = True ccip_read_max_redirects: int = 4 diff --git a/web3/providers/base.py b/web3/providers/base.py index 2155322099..9c40fb6e38 100644 --- a/web3/providers/base.py +++ b/web3/providers/base.py @@ -38,6 +38,7 @@ class BaseProvider: None, ) + is_async = False global_ccip_read_enabled: bool = True ccip_read_max_redirects: int = 4