Skip to content

Feature/asyncify contract #2387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Mar 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ens/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def resolver(self, normal_name: str) -> Optional['Contract']:
resolver_addr = self.ens.caller.resolver(normal_name_to_hash(normal_name))
if is_none_or_zero_address(resolver_addr):
return None
return self._resolverContract(address=resolver_addr)
# TODO: look at possibly removing type ignore when AsyncENS is written
return self._resolverContract(address=resolver_addr) # type: ignore

def reverser(self, target_address: ChecksumAddress) -> Optional['Contract']:
reversed_domain = address_to_reverse_domain(target_address)
Expand Down Expand Up @@ -393,7 +394,8 @@ def _set_resolver(
namehash,
resolver_addr
).transact(transact)
return self._resolverContract(address=resolver_addr)
# TODO: look at possibly removing type ignore when AsyncENS is written
return self._resolverContract(address=resolver_addr) # type: ignore

def _setup_reverse(
self, name: str, address: ChecksumAddress, transact: Optional["TxParams"] = None
Expand Down
3 changes: 2 additions & 1 deletion ethpm/package.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,8 @@ def get_contract_instance(self, name: ContractName, address: Address) -> Contrac
contract_instance = self.w3.eth.contract(
address=address, **contract_kwargs
)
return contract_instance
# TODO: type ignore may be able to be removed after more of AsynContract is finished
return contract_instance # type: ignore

#
# Build Dependencies
Expand Down
7 changes: 3 additions & 4 deletions tests/core/contracts/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,16 +1052,15 @@ def buildTransaction(request):
return functools.partial(invoke_contract, api_call_desig='buildTransaction')


@pytest_asyncio.fixture()
async def async_deploy(web3, Contract, apply_func=identity, args=None):
async def async_deploy(async_web3, Contract, apply_func=identity, args=None):
args = args or []
deploy_txn = await Contract.constructor(*args).transact()
deploy_receipt = await web3.eth.wait_for_transaction_receipt(deploy_txn)
deploy_receipt = await async_web3.eth.wait_for_transaction_receipt(deploy_txn)
assert deploy_receipt is not None
address = apply_func(deploy_receipt['contractAddress'])
contract = Contract(address=address)
assert contract.address == address
assert len(await web3.eth.get_code(contract.address)) > 0
assert len(await async_web3.eth.get_code(contract.address)) > 0
return contract


Expand Down
28 changes: 28 additions & 0 deletions tests/core/contracts/test_contract_call_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,3 +879,31 @@ def test_call_revert_contract(revert_contract):
# which does not contain the revert reason. Avoid that by giving a gas
# value.
revert_contract.functions.revertWithMessage().call({'gas': 100000})


@pytest.mark.asyncio
async def test_async_call_with_no_arguments(async_math_contract, call):
result = await async_math_contract.functions.return13().call()
assert result == 13


@pytest.mark.asyncio
async def test_async_call_with_one_argument(async_math_contract, call):
result = await async_math_contract.functions.multiply7(3).call()
assert result == 21


@pytest.mark.asyncio
async def test_async_returns_data_from_specified_block(async_w3, async_math_contract):
start_num = await async_w3.eth.get_block('latest')
await async_w3.provider.make_request(method='evm_mine', params=[5])
await async_math_contract.functions.increment().transact()
await async_math_contract.functions.increment().transact()

output1 = await async_math_contract.functions.counter().call(
block_identifier=start_num.number + 6)
output2 = await async_math_contract.functions.counter().call(
block_identifier=start_num.number + 7)

assert output1 == 1
assert output2 == 2
2 changes: 1 addition & 1 deletion web3/_utils/module_testing/web3_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def test_solidityKeccak(
self, w3: "Web3", types: Sequence[TypeStr], values: Sequence[Any], expected: HexBytes
) -> None:
if isinstance(expected, type) and issubclass(expected, Exception):
with pytest.raises(expected):
with pytest.raises(expected): # type: ignore
w3.solidityKeccak(types, values)
return

Expand Down
67 changes: 44 additions & 23 deletions web3/contract.py
Original file line number Diff line number Diff line change
Expand Up @@ -840,25 +840,10 @@ def _get_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:
return transact_transaction

@combomethod
def buildTransaction(self, transaction: Optional[TxParams] = None) -> TxParams:
"""
Build the transaction dictionary without sending
"""

if transaction is None:
built_transaction: TxParams = {}
else:
built_transaction = cast(TxParams, dict(**transaction))
self.check_forbidden_keys_in_transaction(built_transaction,
["data", "to"])

if self.w3.eth.default_account is not empty:
# type ignored b/c check prevents an empty default_account
built_transaction.setdefault('from', self.w3.eth.default_account) # type: ignore

built_transaction['data'] = self.data_in_transaction
def _build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:
built_transaction = self._get_transaction(transaction)
built_transaction['to'] = Address(b'')
return fill_transaction_defaults(self.w3, built_transaction)
return built_transaction

@staticmethod
def check_forbidden_keys_in_transaction(
Expand All @@ -877,6 +862,22 @@ class ContractConstructor(BaseContractConstructor):
def transact(self, transaction: Optional[TxParams] = None) -> HexBytes:
return self.w3.eth.send_transaction(self._get_transaction(transaction))

@combomethod
def build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:
"""
Build the transaction dictionary without sending
"""
built_transaction = self._build_transaction(transaction)
return fill_transaction_defaults(self.w3, built_transaction)

@combomethod
@deprecated_for("build_transaction")
def buildTransaction(self, transaction: Optional[TxParams] = None) -> TxParams:
"""
Build the transaction dictionary without sending
"""
return self.build_transaction(transaction)


class AsyncContractConstructor(BaseContractConstructor):

Expand All @@ -885,6 +886,14 @@ async def transact(self, transaction: Optional[TxParams] = None) -> HexBytes:
return await self.w3.eth.send_transaction( # type: ignore
self._get_transaction(transaction))

@combomethod
async def build_transaction(self, transaction: Optional[TxParams] = None) -> TxParams:
"""
Build the transaction dictionary without sending
"""
built_transaction = self._build_transaction(transaction)
return fill_transaction_defaults(self.w3, built_transaction)


class ConciseMethod:
ALLOWED_MODIFIERS = {'call', 'estimateGas', 'transact', 'buildTransaction'}
Expand Down Expand Up @@ -1364,7 +1373,7 @@ async def call(
self._return_data_normalizers,
self.function_identifier,
call_transaction,
block_id,
block_id, # type: ignore
self.contract_abi,
self.abi,
state_override,
Expand Down Expand Up @@ -2020,11 +2029,11 @@ def parse_block_identifier(w3: 'Web3', block_identifier: BlockIdentifier) -> Blo

async def async_parse_block_identifier(w3: 'Web3',
block_identifier: BlockIdentifier
) -> BlockIdentifier:
) -> Awaitable[BlockIdentifier]:
if isinstance(block_identifier, int):
return parse_block_identifier_int(w3, block_identifier)
return await async_parse_block_identifier_int(w3, block_identifier)
elif block_identifier in ['latest', 'earliest', 'pending']:
return block_identifier
return block_identifier # type: ignore
elif isinstance(block_identifier, bytes) or is_hex_encoded_block_hash(block_identifier):
return await w3.eth.get_block(block_identifier)['number'] # type: ignore
else:
Expand All @@ -2042,6 +2051,18 @@ def parse_block_identifier_int(w3: 'Web3', block_identifier_int: int) -> BlockNu
return BlockNumber(block_num)


async def async_parse_block_identifier_int(w3: 'Web3', block_identifier_int: int
) -> Awaitable[BlockNumber]:
if block_identifier_int >= 0:
block_num = block_identifier_int
else:
last_block = await w3.eth.get_block('latest')['number'] # type: ignore
block_num = last_block + block_identifier_int + 1
if block_num < 0:
raise BlockNumberOutofRange
return BlockNumber(block_num) # type: ignore


def transact_with_contract_function(
address: ChecksumAddress,
w3: 'Web3',
Expand Down Expand Up @@ -2167,7 +2188,7 @@ def async_find_functions_by_identifier(

def get_function_by_identifier(
fns: Sequence[ContractFunction], identifier: str
) -> ContractFunction:
) -> Union[ContractFunction, AsyncContractFunction]:
if len(fns) > 1:
raise ValueError(
'Found multiple functions with matching {0}. '
Expand Down
52 changes: 28 additions & 24 deletions web3/eth.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@
replace_transaction,
)
from web3.contract import (
AsyncContract,
AsyncContractCaller,
ConciseContract,
Contract,
ContractCaller,
Expand Down Expand Up @@ -116,6 +118,8 @@ class BaseEth(Module):
_default_block: BlockIdentifier = "latest"
_default_chain_id: Optional[int] = None
gasPriceStrategy = None
defaultContractFactory: Type[Union[Contract, AsyncContract,
ConciseContract, ContractCaller, AsyncContractCaller]] = Contract

_gas_price: Method[Callable[[], Wei]] = Method(
RPC.eth_gasPrice,
Expand Down Expand Up @@ -343,6 +347,30 @@ def call_munger(
mungers=[default_root_munger]
)

@overload
def contract(self, address: None = None, **kwargs: Any) -> Union[Type[Contract], Type[AsyncContract]]: ... # noqa: E704,E501

@overload # noqa: F811
def contract(self, address: Union[Address, ChecksumAddress, ENS], **kwargs: Any) -> Union[Contract, AsyncContract]: ... # noqa: E704,E501

def contract( # noqa: F811
self, address: Optional[Union[Address, ChecksumAddress, ENS]] = None, **kwargs: Any
) -> Union[Type[Contract], Contract, Type[AsyncContract], AsyncContract]:
ContractFactoryClass = kwargs.pop('ContractFactoryClass', self.defaultContractFactory)

ContractFactory = ContractFactoryClass.factory(self.w3, **kwargs)

if address:
return ContractFactory(address)
else:
return ContractFactory

def set_contract_factory(
self, contractFactory: Type[Union[Contract, AsyncContract,
ConciseContract, ContractCaller, AsyncContractCaller]]
) -> None:
self.defaultContractFactory = contractFactory


class AsyncEth(BaseEth):
is_async = True
Expand Down Expand Up @@ -554,7 +582,6 @@ async def call(

class Eth(BaseEth):
account = Account()
defaultContractFactory: Type[Union[Contract, ConciseContract, ContractCaller]] = Contract # noqa: E704,E501
iban = Iban

def namereg(self) -> NoReturn:
Expand Down Expand Up @@ -949,35 +976,12 @@ def filter_munger(
mungers=[default_root_munger],
)

@overload
def contract(self, address: None = None, **kwargs: Any) -> Type[Contract]: ... # noqa: E704,E501

@overload # noqa: F811
def contract(self, address: Union[Address, ChecksumAddress, ENS], **kwargs: Any) -> Contract: ... # noqa: E704,E501

def contract( # noqa: F811
self, address: Optional[Union[Address, ChecksumAddress, ENS]] = None, **kwargs: Any
) -> Union[Type[Contract], Contract]:
ContractFactoryClass = kwargs.pop('ContractFactoryClass', self.defaultContractFactory)

ContractFactory = ContractFactoryClass.factory(self.w3, **kwargs)

if address:
return ContractFactory(address)
else:
return ContractFactory

@deprecated_for("set_contract_factory")
def setContractFactory(
self, contractFactory: Type[Union[Contract, ConciseContract, ContractCaller]]
) -> None:
return self.set_contract_factory(contractFactory)

def set_contract_factory(
self, contractFactory: Type[Union[Contract, ConciseContract, ContractCaller]]
) -> None:
self.defaultContractFactory = contractFactory

def getCompilers(self) -> NoReturn:
raise DeprecationWarning("This method has been deprecated as of EIP 1474.")

Expand Down