Skip to content

Commit 82383bc

Browse files
committed
Mirror sync logic for async name to address
- There wasn't too much code copying in the end with this approach so this seems better. This provides proper recursion for address types and gives async and sync the same treatment as they basically follow the same approach.
1 parent 3ab1509 commit 82383bc

File tree

2 files changed

+92
-19
lines changed

2 files changed

+92
-19
lines changed

web3/_utils/abi.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
import itertools
88
import re
99
from typing import (
10+
TYPE_CHECKING,
1011
Any,
1112
Callable,
1213
Collection,
14+
Coroutine,
1315
Dict,
1416
Iterable,
1517
List,
@@ -53,6 +55,7 @@
5355
decode_hex,
5456
is_bytes,
5557
is_list_like,
58+
is_string,
5659
is_text,
5760
to_text,
5861
to_tuple,
@@ -66,6 +69,9 @@
6669
pipe,
6770
)
6871

72+
from web3._utils.decorators import (
73+
reject_recursive_repeats,
74+
)
6975
from web3._utils.ens import (
7076
is_ens_name,
7177
)
@@ -82,11 +88,17 @@
8288
ABIEventParams,
8389
ABIFunction,
8490
ABIFunctionParams,
91+
TReturn,
8592
)
8693
from web3.utils import ( # public utils module
8794
get_abi_input_names,
8895
)
8996

97+
if TYPE_CHECKING:
98+
from web3 import ( # noqa: F401
99+
AsyncWeb3,
100+
)
101+
90102

91103
def filter_by_type(_type: str, contract_abi: ABI) -> List[Union[ABIFunction, ABIEvent]]:
92104
return [abi for abi in contract_abi if abi["type"] == _type]
@@ -971,3 +983,70 @@ def __new__(self, args: Any) -> "ABIDecodedNamedTuple":
971983
return super().__new__(self, *args)
972984

973985
return ABIDecodedNamedTuple
986+
987+
988+
# -- async -- #
989+
990+
991+
async def async_data_tree_map(
992+
async_w3: "AsyncWeb3",
993+
func: Callable[
994+
["AsyncWeb3", TypeStr, Any], Coroutine[Any, Any, Tuple[TypeStr, Any]]
995+
],
996+
data_tree: Any,
997+
) -> "ABITypedData":
998+
"""
999+
Map an awaitable method to every ABITypedData element in the tree.
1000+
1001+
The awaitable method should receive three positional args:
1002+
async_w3, abi_type, and data
1003+
"""
1004+
1005+
async def async_map_to_typed_data(elements: Any) -> "ABITypedData":
1006+
if isinstance(elements, ABITypedData) and elements.abi_type is not None:
1007+
formatted = await func(async_w3, *elements)
1008+
return ABITypedData(formatted)
1009+
else:
1010+
return elements
1011+
1012+
return await async_recursive_map(async_w3, async_map_to_typed_data, data_tree)
1013+
1014+
1015+
@reject_recursive_repeats
1016+
async def async_recursive_map(
1017+
async_w3: "AsyncWeb3",
1018+
func: Callable[[Any], Coroutine[Any, Any, TReturn]],
1019+
data: Any,
1020+
) -> TReturn:
1021+
"""
1022+
Apply an awaitable method to data and any collection items inside data
1023+
(using async_map_collection).
1024+
1025+
Define the awaitable method so that it only applies to the type of value that you
1026+
want it to apply to.
1027+
"""
1028+
1029+
async def async_recurse(item: Any) -> TReturn:
1030+
return await async_recursive_map(async_w3, func, item)
1031+
1032+
items_mapped = await async_map_if_collection(async_recurse, data)
1033+
return await func(items_mapped)
1034+
1035+
1036+
async def async_map_if_collection(
1037+
func: Callable[[Any], Coroutine[Any, Any, Any]], value: Any
1038+
) -> Any:
1039+
"""
1040+
Apply an awaitable method to each element of a collection or value of a dictionary.
1041+
If the value is not a collection, return it unmodified.
1042+
"""
1043+
1044+
datatype = type(value)
1045+
if isinstance(value, Mapping):
1046+
return datatype({key: await func(val) for key, val in value.values()})
1047+
if is_string(value):
1048+
return value
1049+
elif isinstance(value, Iterable):
1050+
return datatype([await func(item) for item in value])
1051+
else:
1052+
return value

web3/middleware/names.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@
2323

2424
from .._utils.abi import (
2525
abi_data_tree,
26+
async_data_tree_map,
27+
strip_abi_type,
28+
)
29+
from .._utils.formatters import (
30+
recursive_map,
2631
)
2732
from .formatting import (
2833
construct_formatting_middleware,
@@ -52,26 +57,14 @@ async def async_format_all_ens_names_to_address(
5257
abi_types_for_method: Sequence[Any],
5358
data: Sequence[Any],
5459
) -> Sequence[Any]:
60+
# provide a stepwise version of what the curried formatters do
5561
abi_typed_params = abi_data_tree(abi_types_for_method, data)
56-
57-
formatted_params = []
58-
for param in abi_typed_params:
59-
if param.abi_type == "address[]":
60-
# handle name conversion in an address list
61-
# Note: only supports single list atm, as is true the sync middleware
62-
# TODO: handle address[][], etc...
63-
formatted_data = await async_format_all_ens_names_to_address(
64-
async_web3,
65-
[param.abi_type[:-2]] * len(param.data),
66-
[subparam.data for subparam in param.data],
67-
)
68-
else:
69-
_abi_type, formatted_data = await async_abi_ens_resolver(
70-
async_web3,
71-
param.abi_type,
72-
param.data,
73-
)
74-
formatted_params.append(formatted_data)
62+
formatted_data_tree = await async_data_tree_map(
63+
async_web3,
64+
async_abi_ens_resolver,
65+
abi_typed_params,
66+
)
67+
formatted_params = recursive_map(strip_abi_type, formatted_data_tree)
7568
return formatted_params
7669

7770

@@ -96,6 +89,7 @@ async def async_apply_ens_to_address_conversion(
9689
)
9790
formatted_dict = dict(zip(fields, formatted_params))
9891
return (formatted_dict,)
92+
9993
else:
10094
raise TypeError(
10195
f"ABI definitions must be a list or dictionary, "

0 commit comments

Comments
 (0)