Skip to content

Add rich tuple decoder #1353

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 23 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
e8f42aa
Add rich tuple decoder
banteg May 13, 2019
3891f4f
Add decode option to contract
banteg May 14, 2019
7e24a3a
Add tests for named tuple decoder
banteg May 15, 2019
1bdd964
Use dict comprehension when building decoded tuple
banteg May 15, 2019
488a194
Fix tuple decoding in decode_function_input
banteg May 15, 2019
147d98b
Decode tuples as namedtuples instead of dicts
banteg May 16, 2019
70790d3
Add foldable namedtuple such that type(x)(x) == x
banteg May 16, 2019
aa8bddb
Add decode_arguments function that deals with top-level names
banteg May 16, 2019
fcbc528
Use named_arguments_tuple for decoding function inputs/outputs
banteg May 16, 2019
822ad64
Move decode_transaction_data to utils, make it more testable
banteg May 16, 2019
3c08f1e
Add tests for decoding transaction data
banteg May 16, 2019
1e42dda
Make tuples anonymous
banteg May 16, 2019
ade0669
Rediscover old tests, remove duplicate tests
banteg May 16, 2019
4641c9a
Add literal namedtuple constructor
banteg May 17, 2019
4c109c5
Strip leading underscore in namedtuple field names
banteg May 17, 2019
448b1c3
Fallback to tuple on named fields clash
banteg May 18, 2019
5d17e43
Revert decode_function_input test change
banteg May 18, 2019
0a2f931
Add ability to convert namedtuple to dict
banteg May 18, 2019
91c0d36
Don't try to create namedtuple with reserved keywords in fields
banteg May 18, 2019
a01aa99
Use dict when parsing argument names
banteg May 18, 2019
62d11c8
Add tuple support and decoding to events
banteg May 18, 2019
37d7c32
Update named_tree tests
banteg May 18, 2019
e328c5b
fix vyper-specific scalar as struct output
banteg Jul 9, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions tests/core/utilities/test_abi_named_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from web3._utils.abi import (
check_if_arguments_can_be_encoded,
dict_to_namedtuple,
foldable_namedtuple,
named_tree,
)

from .test_abi import (
TEST_FUNCTION_ABI,
)

abi = TEST_FUNCTION_ABI['inputs']

# s = (a=1, b=[2, 3, 4], c=[(x=5, y=6), (x=7, y=8), (x=9, y=10)])
# t = (x=11, y=12)
# a = 13
inputs = (
(1, [2, 3, 4], [(5, 6), (7, 8), (9, 10)]),
(11, 12),
13,
)


def test_named_arguments_decode():
decoded = named_tree(abi, inputs)
data = dict_to_namedtuple(decoded)
assert data == inputs
assert data.s.c[2].y == 10
assert data.t.x == 11
assert data.a == 13


def test_namedtuples_encodable():
kwargs = named_tree(abi, inputs)
args = dict_to_namedtuple(kwargs)
assert check_if_arguments_can_be_encoded(TEST_FUNCTION_ABI, args, {})
assert check_if_arguments_can_be_encoded(TEST_FUNCTION_ABI, (), kwargs)


def test_foldable_namedtuple():
item = foldable_namedtuple(['a', 'b', 'c'])([1, 2, 3])
assert type(item)(item) == item == (1, 2, 3)
assert item.c == 3
64 changes: 64 additions & 0 deletions web3/_utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
)
import copy
import itertools
from keyword import (
kwlist,
)
import re
from typing import (
Any,
Expand Down Expand Up @@ -711,3 +714,64 @@ def strip_abi_type(elements):
return elements.data
else:
return elements


def named_tree(abi, data: tuple):
"""
Convert function inputs/outputs or event data tuple to dict with names taken from ABI.
"""
names = [item['name'] for item in abi]
items = [named_subtree(*item) for item in zip(abi, data)]
return dict(zip(names, items)) if all(names) else items


def named_subtree(abi, data):
abi_type = parse(collapse_if_tuple(abi))

if abi_type.is_array:
item_type = abi_type.item_type.to_type_str()
item_abi = {**abi, 'type': item_type, 'name': ''}
items = [named_subtree(item_abi, item) for item in data]
return items

if isinstance(abi_type, TupleType):
names = [item['name'] for item in abi['components']]
items = [named_subtree(*item) for item in zip(abi['components'], data)]
return dict(zip(names, items))

return data


def dict_to_namedtuple(data):
def to_tuple(item):
if isinstance(item, dict):
return Tuple(**item)
return item

return recursive_map(to_tuple, data)


def foldable_namedtuple(fields):
"""
Customized namedtuple such that `type(x)(x) == x`.
"""
fields = [field.lstrip('_') for field in fields]
if '' in fields or len(set(fields)) < len(fields) or set(fields) & set(kwlist):
return tuple

class Tuple(namedtuple('Tuple', fields)):
def __new__(self, args):
return super().__new__(self, *args)

def _asdict(self):
return dict(super()._asdict())

return Tuple


def Tuple(**kwargs):
"""
Literal namedtuple constructor such that `Tuple(x=1, y=2)` returns `Tuple(x=1, y=2)`.
"""
keys, values = zip(*kwargs.items())
return foldable_namedtuple(keys)(values)
11 changes: 11 additions & 0 deletions web3/_utils/contracts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools

from eth_abi import (
decode_abi,
encode_abi as eth_abi_encode_abi,
)
from eth_utils import (
Expand Down Expand Up @@ -30,6 +31,7 @@
get_fallback_func_abi,
map_abi_data,
merge_args_and_kwargs,
named_tree,
)
from web3._utils.encoding import (
to_hex,
Expand Down Expand Up @@ -220,6 +222,15 @@ def encode_transaction_data(
return add_0x_prefix(encode_abi(web3, fn_abi, fn_arguments, fn_selector))


def decode_transaction_data(fn_abi, data, normalizers=None):
data = HexBytes(data)
types = get_abi_input_types(fn_abi)
decoded = decode_abi(types, data[4:])
if normalizers:
decoded = map_abi_data(normalizers, types, decoded)
return named_tree(fn_abi['inputs'], decoded)


def get_fallback_function_info(contract_abi=None, fn_abi=None):
if fn_abi is None:
fn_abi = get_fallback_func_abi(contract_abi)
Expand Down
12 changes: 10 additions & 2 deletions web3/_utils/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
to_hex,
to_tuple,
)
from eth_utils.abi import (
collapse_if_tuple,
)
from eth_utils.toolz import (
complement,
compose,
Expand Down Expand Up @@ -54,6 +57,7 @@
get_abi_input_names,
get_indexed_event_inputs,
map_abi_data,
named_tree,
normalize_event_input_types,
)

Expand Down Expand Up @@ -152,7 +156,7 @@ def get_event_abi_types_for_decoding(event_inputs):
if input_abi['indexed'] and is_dynamic_sized_type(input_abi['type']):
yield 'bytes32'
else:
yield input_abi['type']
yield collapse_if_tuple(input_abi)


@curry
Expand Down Expand Up @@ -202,6 +206,10 @@ def get_event_data(event_abi, log_entry):
log_data_types,
decoded_log_data
)
named_log_data = named_tree(
log_data_normalized_inputs,
normalized_log_data,
)

decoded_topic_data = [
decode_single(topic_type, topic_data)
Expand All @@ -216,7 +224,7 @@ def get_event_data(event_abi, log_entry):

event_args = dict(itertools.chain(
zip(log_topic_names, normalized_topic_data),
zip(log_data_names, normalized_log_data),
named_log_data.items(),
))

event_data = {
Expand Down
46 changes: 31 additions & 15 deletions web3/contract.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""Interaction with smart contracts over Web3 connector.

"""
from collections.abc import (
Sequence,
)
import copy
import itertools

Expand Down Expand Up @@ -29,18 +32,21 @@
from web3._utils.abi import (
abi_to_signature,
check_if_arguments_can_be_encoded,
dict_to_namedtuple,
fallback_func_abi_exists,
filter_by_type,
get_abi_output_types,
get_constructor_abi,
is_array_type,
map_abi_data,
merge_args_and_kwargs,
named_tree,
)
from web3._utils.blocks import (
is_hex_encoded_block_hash,
)
from web3._utils.contracts import (
decode_transaction_data,
encode_abi,
find_matching_event_abi,
find_matching_fn_abi,
Expand Down Expand Up @@ -99,7 +105,7 @@ class ContractFunctions:
"""Class containing contract function objects
"""

def __init__(self, abi, web3, address=None):
def __init__(self, abi, web3, address=None, decode=False):
self.abi = abi
self.web3 = web3
self.address = address
Expand All @@ -115,7 +121,8 @@ def __init__(self, abi, web3, address=None):
web3=self.web3,
contract_abi=self.abi,
address=self.address,
function_identifier=func['name']))
function_identifier=func['name'],
decode=decode))

def __iter__(self):
if not hasattr(self, '_functions') or not self._functions:
Expand Down Expand Up @@ -243,6 +250,7 @@ class Contract:

functions = None
caller = None
decode = None

#: Instance of :class:`ContractEvents` presenting available Event ABIs
events = None
Expand Down Expand Up @@ -272,8 +280,8 @@ def __init__(self, address=None):
if not self.address:
raise TypeError("The address argument is required to instantiate a contract.")

self.functions = ContractFunctions(self.abi, self.web3, self.address)
self.caller = ContractCaller(self.abi, self.web3, self.address)
self.functions = ContractFunctions(self.abi, self.web3, self.address, decode=self.decode)
self.caller = ContractCaller(self.abi, self.web3, self.address, decode=self.decode)
self.events = ContractEvents(self.abi, self.web3, self.address)
self.fallback = Contract.get_fallback_function(self.abi, self.web3, self.address)

Expand All @@ -295,6 +303,7 @@ def factory(cls, web3, class_name=None, **kwargs):
kwargs,
normalizers=normalizers,
)
contract.decode = kwargs.get('decode', False)
contract.functions = ContractFunctions(contract.abi, contract.web3)
contract.caller = ContractCaller(contract.abi, contract.web3, contract.address)
contract.events = ContractEvents(contract.abi, contract.web3)
Expand Down Expand Up @@ -388,13 +397,9 @@ def callable_check(fn_abi):
@combomethod
def decode_function_input(self, data):
data = HexBytes(data)
selector, params = data[:4], data[4:]
func = self.get_function_by_selector(selector)
names = [x['name'] for x in func.abi['inputs']]
types = [x['type'] for x in func.abi['inputs']]
decoded = decode_abi(types, params)
normalized = map_abi_data(BASE_RETURN_NORMALIZERS, types, decoded)
return func, dict(zip(names, normalized))
func = self.get_function_by_selector(data[:4])
arguments = decode_transaction_data(func.abi, data, normalizers=BASE_RETURN_NORMALIZERS)
return func, arguments

@combomethod
def find_functions_by_args(self, *args):
Expand Down Expand Up @@ -722,6 +727,7 @@ class ContractFunction:
abi = None
transaction = None
arguments = None
decode = None

def __init__(self, abi=None):
self.abi = abi
Expand Down Expand Up @@ -819,6 +825,7 @@ def call(self, transaction=None, block_identifier='latest'):
block_id,
self.contract_abi,
self.abi,
self.decode,
*self.args,
**self.kwargs
)
Expand Down Expand Up @@ -1192,11 +1199,13 @@ def __init__(self,
web3,
address,
transaction=None,
block_identifier='latest'):
block_identifier='latest',
decode=False):
self.web3 = web3
self.address = address
self.abi = abi
self._functions = None
self.decode = decode

if self.abi:
if transaction is None:
Expand All @@ -1209,7 +1218,8 @@ def __init__(self,
web3=self.web3,
contract_abi=self.abi,
address=self.address,
function_identifier=func['name'])
function_identifier=func['name'],
decode=decode)

block_id = parse_block_identifier(self.web3, block_identifier)
caller_method = partial(self.call_function,
Expand Down Expand Up @@ -1247,7 +1257,8 @@ def __call__(self, transaction=None, block_identifier='latest'):
self.web3,
self.address,
transaction=transaction,
block_identifier=block_identifier)
block_identifier=block_identifier,
decode=self.decode)

@staticmethod
def call_function(fn, *args, transaction=None, block_identifier='latest', **kwargs):
Expand Down Expand Up @@ -1281,6 +1292,7 @@ def call_contract_function(
block_id=None,
contract_abi=None,
fn_abi=None,
decode=False,
*args,
**kwargs):
"""
Expand Down Expand Up @@ -1339,7 +1351,11 @@ def call_contract_function(
)
normalized_data = map_abi_data(_normalizers, output_types, output_data)

if len(normalized_data) == 1:
if decode:
decoded = named_tree(fn_abi['outputs'], normalized_data)
normalized_data = dict_to_namedtuple(decoded)

if isinstance(normalized_data, Sequence) and len(normalized_data) == 1:
return normalized_data[0]
else:
return normalized_data
Expand Down