diff --git a/docs/abi_types.rst b/docs/abi_types.rst index 6c3c582462..0a9097cca3 100644 --- a/docs/abi_types.rst +++ b/docs/abi_types.rst @@ -20,10 +20,11 @@ Ethereum Addresses All addresses must be supplied in one of three ways: -* While connected to mainnet, an Ethereum Name Service name (often in the form ``myname.eth``) * A 20-byte hexadecimal that is checksummed using the `EIP-55 `_ spec. -* A 20-byte binary address. +* A 20-byte binary address (python bytes type). +* While connected to an Ethereum Name Service (ENS) supported chain, an ENS name + (often in the form ``myname.eth``). Disabling Strict Bytes Type Checking ------------------------------------ diff --git a/docs/middleware.rst b/docs/middleware.rst index 166cdb41cb..64215a8763 100644 --- a/docs/middleware.rst +++ b/docs/middleware.rst @@ -40,6 +40,7 @@ Sync middlewares include: Async middlewares include: * ``gas_price_strategy`` +* ``name_to_address`` * ``attrdict`` * ``validation`` * ``gas_estimate`` @@ -66,15 +67,16 @@ AttributeDict ~~~~~~~~~~~~~~~~~~~~~ .. py:method:: web3.middleware.name_to_address_middleware + web3.middleware.async_name_to_address_middleware This middleware converts Ethereum Name Service (ENS) names into the address that the name points to. For example :meth:`w3.eth.send_transaction ` will accept .eth names in the 'from' and 'to' fields. .. note:: - This middleware only converts ENS names if invoked with the mainnet - (where the ENS contract is deployed), for all other cases will result in an - ``InvalidAddress`` error + This middleware only converts ENS names on chains where the proper ENS + contracts are deployed to support this functionality. All other cases will + result in a ``NameNotFound`` error. Gas Price Strategy ~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/newsfragments/3012.docs.rst b/newsfragments/3012.docs.rst new file mode 100644 index 0000000000..1a83d51241 --- /dev/null +++ b/newsfragments/3012.docs.rst @@ -0,0 +1 @@ +Update documentation relating to ENS only being available on mainnet. ENS is available on all networks where the ENS contracts are deployed. diff --git a/newsfragments/3012.feature.rst b/newsfragments/3012.feature.rst new file mode 100644 index 0000000000..6164dd5c44 --- /dev/null +++ b/newsfragments/3012.feature.rst @@ -0,0 +1 @@ +Add async support for ENS name-to-address resolution via ``async_name_to_address_middleware``. diff --git a/tests/core/manager/test_default_middlewares.py b/tests/core/manager/test_default_middlewares.py index 59c3736335..64b8a5d1d2 100644 --- a/tests/core/manager/test_default_middlewares.py +++ b/tests/core/manager/test_default_middlewares.py @@ -6,6 +6,7 @@ async_attrdict_middleware, async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, + async_name_to_address_middleware, async_validation_middleware, attrdict_middleware, buffered_gas_estimate_middleware, @@ -15,7 +16,7 @@ ) -def test_default_sync_middlwares(w3): +def test_default_sync_middlewares(w3): expected_middlewares = [ (gas_price_strategy_middleware, "gas_price_strategy"), (name_to_address_middleware(w3), "name_to_address"), @@ -32,9 +33,10 @@ def test_default_sync_middlwares(w3): assert default_middlewares[x][1] == expected_middlewares[x][1] -def test_default_async_middlwares(): +def test_default_async_middlewares(): expected_middlewares = [ (async_gas_price_strategy_middleware, "gas_price_strategy"), + (async_name_to_address_middleware, "name_to_address"), (async_attrdict_middleware, "attrdict"), (async_validation_middleware, "validation"), (async_buffered_gas_estimate_middleware, "gas_estimate"), diff --git a/tests/core/middleware/test_name_to_address_middleware.py b/tests/core/middleware/test_name_to_address_middleware.py index da8406ef4b..d6dc040942 100644 --- a/tests/core/middleware/test_name_to_address_middleware.py +++ b/tests/core/middleware/test_name_to_address_middleware.py @@ -1,23 +1,27 @@ import pytest +import pytest_asyncio + from web3 import ( + AsyncWeb3, Web3, - constants, ) from web3.exceptions import ( InvalidAddress, + NameNotFound, ) -from web3.middleware import ( # noqa: F401 - construct_fixture_middleware, +from web3.middleware import ( name_to_address_middleware, ) -from web3.providers.base import ( - BaseProvider, +from web3.middleware.names import ( + async_name_to_address_middleware, +) +from web3.providers.eth_tester import ( + AsyncEthereumTesterProvider, + EthereumTesterProvider, ) -NAME = "dump.eth" -ADDRESS = constants.ADDRESS_ZERO -BALANCE = 0 +NAME = "tester.eth" class TempENS: @@ -29,31 +33,150 @@ def address(self, name): @pytest.fixture -def w3(): - w3 = Web3(provider=BaseProvider(), middlewares=[]) - w3.ens = TempENS({NAME: ADDRESS}) - w3.middleware_onion.add(name_to_address_middleware(w3)) - return w3 +def _w3_setup(): + return Web3(provider=EthereumTesterProvider(), middlewares=[]) + + +@pytest.fixture +def ens_mapped_address(_w3_setup): + return _w3_setup.eth.accounts[0] + + +@pytest.fixture +def ens_addr_account_balance(ens_mapped_address, _w3_setup): + return _w3_setup.eth.get_balance(ens_mapped_address) + + +@pytest.fixture +def w3(_w3_setup, ens_mapped_address): + _w3_setup.ens = TempENS({NAME: ens_mapped_address}) + _w3_setup.middleware_onion.add(name_to_address_middleware(_w3_setup)) + return _w3_setup + +@pytest.mark.parametrize( + "params", + ( + [NAME, "latest"], + [NAME, 0], + [NAME], + ), +) +def test_pass_name_resolver_get_balance_list_args( + w3, + ens_addr_account_balance, + params, +): + assert w3.eth.get_balance(*params) == ens_addr_account_balance -def test_pass_name_resolver(w3): - return_chain_on_mainnet = construct_fixture_middleware( + +@pytest.mark.parametrize( + "params", + ( + {"value": 1, "from": NAME, "to": NAME, "gas": 21000}, { - "net_version": "1", - } - ) - return_balance = construct_fixture_middleware({"eth_getBalance": BALANCE}) - w3.middleware_onion.inject(return_chain_on_mainnet, layer=0) - w3.middleware_onion.inject(return_balance, layer=0) - assert w3.eth.get_balance(NAME) == BALANCE + "value": 1, + "maxPriorityFeePerGas": 10**9, + "from": NAME, + "to": NAME, + "gas": 21000, + }, + {"value": 1, "to": NAME, "gas": 21000, "from": NAME}, + ), +) +def test_pass_name_resolver_send_transaction_dict_args( + w3, + params, + ens_mapped_address, +): + tx_hash = w3.eth.send_transaction(params) + + tx = w3.eth.get_transaction(tx_hash) + assert tx["from"] == ens_mapped_address + assert tx["to"] == ens_mapped_address def test_fail_name_resolver(w3): - return_chain_on_mainnet = construct_fixture_middleware( - { - "net_version": "2", - } - ) - w3.middleware_onion.inject(return_chain_on_mainnet, layer=0) with pytest.raises(InvalidAddress, match=r".*ethereum\.eth.*"): w3.eth.get_balance("ethereum.eth") + + +# --- async --- # + + +class AsyncTempENS(TempENS): + async def address(self, name): + return self.registry.get(name, None) + + +@pytest_asyncio.fixture +async def _async_w3_setup(): + return AsyncWeb3(provider=AsyncEthereumTesterProvider(), middlewares=[]) + + +@pytest_asyncio.fixture +async def async_ens_mapped_address(_async_w3_setup): + accts = await _async_w3_setup.eth.accounts + return accts[0] + + +@pytest_asyncio.fixture +async def async_ens_addr_account_balance(async_ens_mapped_address, _async_w3_setup): + return await _async_w3_setup.eth.get_balance(async_ens_mapped_address) + + +@pytest_asyncio.fixture +async def async_w3(_async_w3_setup, async_ens_mapped_address): + _async_w3_setup.ens = AsyncTempENS({NAME: async_ens_mapped_address}) + _async_w3_setup.middleware_onion.add(async_name_to_address_middleware) + return _async_w3_setup + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "params", + ( + [NAME, "latest"], + [NAME, 0], + [NAME], + ), +) +async def test_async_pass_name_resolver_get_balance_list_args( + async_w3, + async_ens_addr_account_balance, + params, +): + assert await async_w3.eth.get_balance(*params) == async_ens_addr_account_balance + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "params", + ( + {"value": 1, "from": NAME, "to": NAME, "gas": 21000}, + { + "value": 1, + "maxPriorityFeePerGas": 10**9, + "from": NAME, + "to": NAME, + "gas": 21000, + }, + {"value": 1, "to": NAME, "gas": 21000, "from": NAME}, + ), +) +async def test_async_pass_name_resolver_send_transaction_dict_args( + async_w3, + params, + async_ens_mapped_address, +): + tx_hash = await async_w3.eth.send_transaction(params) + + tx = await async_w3.eth.get_transaction(tx_hash) + assert tx["from"] == async_ens_mapped_address + assert tx["to"] == async_ens_mapped_address + + +@pytest.mark.asyncio +async def test_async_fail_name_resolver(async_w3): + with pytest.raises(NameNotFound, match=r".*ethereum\.eth.*"): + await async_w3.eth.get_balance("ethereum.eth") diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 05ab031f5f..eb18947e5e 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -26,6 +26,7 @@ async_attrdict_middleware, async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, + async_name_to_address_middleware, async_validation_middleware, ) from web3.net import ( @@ -81,12 +82,16 @@ def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> # the following length check should fail and will need to be added to once more # async middlewares are added to the defaults - assert len(async_w3.middleware_onion.middlewares) == 4 + assert len(async_w3.middleware_onion.middlewares) == 5 assert ( async_w3.middleware_onion.get("gas_price_strategy") == async_gas_price_strategy_middleware ) + assert ( + async_w3.middleware_onion.get("name_to_address") + == async_name_to_address_middleware + ) assert async_w3.middleware_onion.get("attrdict") == async_attrdict_middleware assert async_w3.middleware_onion.get("validation") == async_validation_middleware assert ( diff --git a/web3/_utils/abi.py b/web3/_utils/abi.py index cfaf9cc5d7..4b0893d9cc 100644 --- a/web3/_utils/abi.py +++ b/web3/_utils/abi.py @@ -7,9 +7,11 @@ import itertools import re from typing import ( + TYPE_CHECKING, Any, Callable, Collection, + Coroutine, Dict, Iterable, List, @@ -53,6 +55,7 @@ decode_hex, is_bytes, is_list_like, + is_string, is_text, to_text, to_tuple, @@ -66,6 +69,9 @@ pipe, ) +from web3._utils.decorators import ( + reject_recursive_repeats, +) from web3._utils.ens import ( is_ens_name, ) @@ -82,11 +88,17 @@ ABIEventParams, ABIFunction, ABIFunctionParams, + TReturn, ) from web3.utils import ( # public utils module get_abi_input_names, ) +if TYPE_CHECKING: + from web3 import ( # noqa: F401 + AsyncWeb3, + ) + def filter_by_type(_type: str, contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]: return [abi for abi in contract_abi if abi["type"] == _type] @@ -971,3 +983,70 @@ def __new__(self, args: Any) -> "ABIDecodedNamedTuple": return super().__new__(self, *args) return ABIDecodedNamedTuple + + +# -- async -- # + + +async def async_data_tree_map( + async_w3: "AsyncWeb3", + func: Callable[ + ["AsyncWeb3", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]] + ], + data_tree: Any, +) -> "ABITypedData": + """ + Map an awaitable method to every ABITypedData element in the tree. + + The awaitable method should receive three positional args: + async_w3, abi_type, and data + """ + + async def async_map_to_typed_data(elements: Any) -> "ABITypedData": + if isinstance(elements, ABITypedData) and elements.abi_type is not None: + formatted = await func(async_w3, *elements) + return ABITypedData(formatted) + else: + return elements + + return await async_recursive_map(async_w3, async_map_to_typed_data, data_tree) + + +@reject_recursive_repeats +async def async_recursive_map( + async_w3: "AsyncWeb3", + func: Callable[[Any], Coroutine[Any, Any, TReturn]], + data: Any, +) -> TReturn: + """ + Apply an awaitable method to data and any collection items inside data + (using async_map_collection). + + Define the awaitable method so that it only applies to the type of value that you + want it to apply to. + """ + + async def async_recurse(item: Any) -> TReturn: + return await async_recursive_map(async_w3, func, item) + + items_mapped = await async_map_if_collection(async_recurse, data) + return await func(items_mapped) + + +async def async_map_if_collection( + func: Callable[[Any], Coroutine[Any, Any, Any]], value: Any +) -> Any: + """ + Apply an awaitable method to each element of a collection or value of a dictionary. + If the value is not a collection, return it unmodified. + """ + + datatype = type(value) + if isinstance(value, Mapping): + return datatype({key: await func(val) for key, val in value.values()}) + if is_string(value): + return value + elif isinstance(value, Iterable): + return datatype([await func(item) for item in value]) + else: + return value diff --git a/web3/_utils/ens.py b/web3/_utils/ens.py index 41f0e3ba6e..a1bd8a1437 100644 --- a/web3/_utils/ens.py +++ b/web3/_utils/ens.py @@ -90,3 +90,15 @@ def contract_ens_addresses( """ with ens_addresses(contract.w3, name_addr_pairs): yield + + +# --- async --- # + + +async def async_validate_name_has_address( + async_ens: AsyncENS, name: str +) -> ChecksumAddress: + addr = await async_ens.address(name) + if not addr: + raise NameNotFound(f"Could not find address for name {name!r}") + return addr diff --git a/web3/_utils/normalizers.py b/web3/_utils/normalizers.py index e32ddb606d..a748c1a6b3 100644 --- a/web3/_utils/normalizers.py +++ b/web3/_utils/normalizers.py @@ -41,6 +41,7 @@ from ens import ( ENS, + AsyncENS, ) from web3._utils.encoding import ( hexstr_if_str, @@ -48,6 +49,7 @@ ) from web3._utils.ens import ( StaticENS, + async_validate_name_has_address, is_ens_name, validate_name_has_address, ) @@ -57,13 +59,17 @@ ) from web3.exceptions import ( InvalidAddress, + NameNotFound, ) from web3.types import ( ABI, ) if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) def implicitly_identity( @@ -214,18 +220,20 @@ def abi_ens_resolver( ) _ens = cast(ENS, w3.ens) - net_version = int(w3.net.version) if hasattr(w3, "net") else None if _ens is None: raise InvalidAddress( - f"Could not look up name {val!r} because ENS is" " set to None" - ) - elif net_version != 1 and not isinstance(_ens, StaticENS): - raise InvalidAddress( - f"Could not look up name {val!r} because web3 is" - " not connected to mainnet" + f"Could not look up name {val!r} because ENS is set to None" ) else: - return type_str, validate_name_has_address(_ens, val) + try: + return type_str, validate_name_has_address(_ens, val) + except NameNotFound as e: + # TODO: This try/except is to keep backwards compatibility when we + # removed the mainnet requirement. Remove this in web3.py v7 and allow + # NameNotFound to raise. + if not isinstance(_ens, StaticENS): + raise InvalidAddress(f"{e}") + raise e else: return type_str, val @@ -267,3 +275,31 @@ def normalize_bytecode(bytecode: bytes) -> HexBytes: bytecode = HexBytes(bytecode) # type ignored b/c bytecode is converted to HexBytes above return bytecode # type: ignore + + +# --- async -- # + + +async def async_abi_ens_resolver( + async_w3: "AsyncWeb3", + type_str: TypeStr, + val: Any, +) -> Tuple[TypeStr, Any]: + if type_str == "address" and is_ens_name(val): + if async_w3 is None: + raise InvalidAddress( + f"Could not look up name {val!r} because no web3" + " connection available" + ) + + _async_ens = cast(AsyncENS, async_w3.ens) + if _async_ens is None: + raise InvalidAddress( + f"Could not look up name {val!r} because ENS is set to None" + ) + else: + address = await async_validate_name_has_address(_async_ens, val) + return type_str, address + + else: + return type_str, val diff --git a/web3/_utils/rpc_abi.py b/web3/_utils/rpc_abi.py index 3c640b53bd..e713858c5e 100644 --- a/web3/_utils/rpc_abi.py +++ b/web3/_utils/rpc_abi.py @@ -5,6 +5,7 @@ Iterable, Sequence, Tuple, + Union, ) from eth_typing import ( @@ -185,7 +186,7 @@ class RPC: "count": "int", } -RPC_ABIS = { +RPC_ABIS: Dict[str, Union[Sequence[Any], Dict[str, str]]] = { # eth "eth_call": TRANSACTION_PARAMS_ABIS, "eth_estimateGas": TRANSACTION_PARAMS_ABIS, diff --git a/web3/manager.py b/web3/manager.py index 580cbc203d..383f7ac7f2 100644 --- a/web3/manager.py +++ b/web3/manager.py @@ -30,6 +30,7 @@ async_attrdict_middleware, async_buffered_gas_estimate_middleware, async_gas_price_strategy_middleware, + async_name_to_address_middleware, async_validation_middleware, attrdict_middleware, buffered_gas_estimate_middleware, @@ -139,7 +140,7 @@ def default_middlewares(w3: "Web3") -> List[Tuple[Middleware, str]]: """ return [ (gas_price_strategy_middleware, "gas_price_strategy"), - (name_to_address_middleware(w3), "name_to_address"), # Add Async + (name_to_address_middleware(w3), "name_to_address"), (attrdict_middleware, "attrdict"), (validation_middleware, "validation"), (abi_middleware, "abi"), @@ -154,6 +155,7 @@ def async_default_middlewares() -> List[Tuple[AsyncMiddleware, str]]: """ return [ (async_gas_price_strategy_middleware, "gas_price_strategy"), + (async_name_to_address_middleware, "name_to_address"), (async_attrdict_middleware, "attrdict"), (async_validation_middleware, "validation"), (async_buffered_gas_estimate_middleware, "gas_estimate"), diff --git a/web3/middleware/__init__.py b/web3/middleware/__init__.py index 607ee27d79..a0b7440e97 100644 --- a/web3/middleware/__init__.py +++ b/web3/middleware/__init__.py @@ -66,6 +66,7 @@ geth_poa_middleware, ) from .names import ( # noqa: F401 + async_name_to_address_middleware, name_to_address_middleware, ) from .normalize_request_parameters import ( # noqa: F401 diff --git a/web3/middleware/names.py b/web3/middleware/names.py index 947863bd1f..df9655a8cb 100644 --- a/web3/middleware/names.py +++ b/web3/middleware/names.py @@ -1,24 +1,47 @@ from typing import ( TYPE_CHECKING, + Any, + Callable, + Dict, + Sequence, + Union, +) + +from toolz import ( + merge, ) from web3._utils.normalizers import ( abi_ens_resolver, + async_abi_ens_resolver, ) from web3._utils.rpc_abi import ( RPC_ABIS, abi_request_formatters, ) from web3.types import ( + AsyncMiddlewareCoroutine, Middleware, + RPCEndpoint, ) +from .._utils.abi import ( + abi_data_tree, + async_data_tree_map, + strip_abi_type, +) +from .._utils.formatters import ( + recursive_map, +) from .formatting import ( construct_formatting_middleware, ) if TYPE_CHECKING: - from web3 import Web3 # noqa: F401 + from web3 import ( # noqa: F401 + AsyncWeb3, + Web3, + ) def name_to_address_middleware(w3: "Web3") -> Middleware: @@ -28,3 +51,72 @@ def name_to_address_middleware(w3: "Web3") -> Middleware: return construct_formatting_middleware( request_formatters=abi_request_formatters(normalizers, RPC_ABIS) ) + + +# -- async -- # + + +async def async_format_all_ens_names_to_address( + async_web3: "AsyncWeb3", + abi_types_for_method: Sequence[Any], + data: Sequence[Any], +) -> Sequence[Any]: + # provide a stepwise version of what the curried formatters do + abi_typed_params = abi_data_tree(abi_types_for_method, data) + formatted_data_tree = await async_data_tree_map( + async_web3, + async_abi_ens_resolver, + abi_typed_params, + ) + formatted_params = recursive_map(strip_abi_type, formatted_data_tree) + return formatted_params + + +async def async_apply_ens_to_address_conversion( + async_web3: "AsyncWeb3", + params: Any, + abi_types_for_method: Union[Sequence[str], Dict[str, str]], +) -> Any: + if isinstance(abi_types_for_method, Sequence): + formatted_params = await async_format_all_ens_names_to_address( + async_web3, abi_types_for_method, params + ) + return formatted_params + + elif isinstance(abi_types_for_method, dict): + # first arg is a dict but other args may be preset + # e.g. eth_call({...}, "latest") + # this is similar to applying a dict formatter at index 0 of the args + param_dict = params[0] + fields = list(abi_types_for_method.keys() & param_dict.keys()) + formatted_params = await async_format_all_ens_names_to_address( + async_web3, + [abi_types_for_method[field] for field in fields], + [param_dict[field] for field in fields], + ) + formatted_dict = dict(zip(fields, formatted_params)) + formatted_params_dict = merge(param_dict, formatted_dict) + return (formatted_params_dict, *params[1:]) + + else: + raise TypeError( + f"ABI definitions must be a list or dictionary, " + f"got {abi_types_for_method!r}" + ) + + +async def async_name_to_address_middleware( + make_request: Callable[[RPCEndpoint, Any], Any], + async_w3: "AsyncWeb3", +) -> AsyncMiddlewareCoroutine: + async def middleware(method: RPCEndpoint, params: Any) -> Any: + abi_types_for_method = RPC_ABIS.get(method, None) + if abi_types_for_method is not None: + params = await async_apply_ens_to_address_conversion( + async_w3, + params, + abi_types_for_method, + ) + return await make_request(method, params) + + return middleware