Skip to content

Clean up AsyncHTTPProvider instantiation #2736

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 4 additions & 12 deletions docs/providers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -282,18 +282,10 @@ AsyncHTTPProvider
>>> from web3.net import AsyncNet
>>> from web3.geth import Geth, AsyncGethTxPool

>>> w3 = Web3(
... AsyncHTTPProvider(endpoint_uri),
... modules={'eth': (AsyncEth,),
... 'net': (AsyncNet,),
... 'geth': (Geth,
... {'txpool': (AsyncGethTxPool,),
... 'personal': (AsyncGethPersonal,),
... 'admin' : (AsyncGethAdmin,)})
... },
... middlewares=[] # See supported middleware section below for middleware options
... )
>>> custom_session = ClientSession() # If you want to pass in your own session
>>> w3 = Web3(AsyncHTTPProvider(endpoint_uri))

>>> # If you want to pass in your own session:
>>> custom_session = ClientSession()
>>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly

Under the hood, the ``AsyncHTTPProvider`` uses the python
Expand Down
1 change: 1 addition & 0 deletions newsfragments/2736.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Load the ``AsyncHTTPProvider`` with default async middleware and default async modules, just as the ``HTTPProvider``.
9 changes: 5 additions & 4 deletions tests/core/contracts/test_contract_call_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
FallbackNotFound,
InvalidAddress,
MismatchedABI,
NameNotFound,
NoABIFound,
NoABIFunctionsFound,
ValidationError,
Expand Down Expand Up @@ -490,9 +491,9 @@ def test_call_address_reflector_name_array(address_reflector_contract, call):
assert addresses == result


def test_call_reject_invalid_ens_name(address_reflector_contract, call):
def test_call_rejects_invalid_ens_name(address_reflector_contract, call):
with contract_ens_addresses(address_reflector_contract, []):
with pytest.raises(ValueError):
with pytest.raises(NameNotFound):
call(
contract=address_reflector_contract,
contract_function="reflect",
Expand Down Expand Up @@ -1363,11 +1364,11 @@ async def test_async_call_address_reflector_name_array(

@pytest.mark.xfail
@pytest.mark.asyncio
async def test_async_call_reject_invalid_ens_name(
async def test_async_call_rejects_invalid_ens_name(
async_address_reflector_contract, async_call
):
with contract_ens_addresses(async_address_reflector_contract, []):
with pytest.raises(ValueError):
with pytest.raises(NameNotFound):
await async_call(
contract=async_address_reflector_contract,
contract_function="reflect",
Expand Down
64 changes: 63 additions & 1 deletion tests/core/providers/test_async_http_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,80 @@
ClientSession,
)

from web3 import Web3
from web3._utils import (
request,
)
from web3.eth import (
AsyncEth,
)
from web3.geth import (
AsyncGethAdmin,
AsyncGethPersonal,
AsyncGethTxPool,
Geth,
)
from web3.middleware import (
async_buffered_gas_estimate_middleware,
async_gas_price_strategy_middleware,
async_validation_middleware,
)
from web3.net import (
AsyncNet,
)
from web3.providers.async_rpc import (
AsyncHTTPProvider,
)

URI = "http://mynode.local:8545"


def test_no_args():
provider = AsyncHTTPProvider()
w3 = Web3(provider)
assert w3.manager.provider == provider
assert w3.manager.provider.is_async


def test_init_kwargs():
provider = AsyncHTTPProvider(endpoint_uri=URI, request_kwargs={"timeout": 60})
w3 = Web3(provider)
assert w3.manager.provider == provider


def test_web3_with_async_http_provider_has_default_middlewares_and_modules() -> None:
async_w3 = Web3(AsyncHTTPProvider(endpoint_uri=URI))

# assert default modules

assert isinstance(async_w3.eth, AsyncEth)
assert isinstance(async_w3.net, AsyncNet)
assert isinstance(async_w3.geth, Geth)
assert isinstance(async_w3.geth.admin, AsyncGethAdmin)
assert isinstance(async_w3.geth.personal, AsyncGethPersonal)
assert isinstance(async_w3.geth.txpool, AsyncGethTxPool)

# assert default middleware

# the following length check should fail and will need to be added to once more
# async middlewares are added to the defaults
assert len(async_w3.middleware_onion.middlewares) == 3

assert (
async_w3.middleware_onion.get("gas_price_strategy")
== async_gas_price_strategy_middleware
)
assert async_w3.middleware_onion.get("validation") == async_validation_middleware
assert (
async_w3.middleware_onion.get("gas_estimate")
== async_buffered_gas_estimate_middleware
)


@pytest.mark.asyncio
async def test_user_provided_session() -> None:
session = ClientSession()
provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545")
provider = AsyncHTTPProvider(endpoint_uri=URI)
cached_session = await provider.cache_async_session(session)
assert len(request._async_session_cache) == 1
assert cached_session == session
1 change: 1 addition & 0 deletions tests/core/providers/test_http_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def test_no_args():
provider = HTTPProvider()
w3 = Web3(provider)
assert w3.manager.provider == provider
assert not w3.manager.provider.is_async


def test_init_kwargs():
Expand Down
50 changes: 7 additions & 43 deletions tests/integration/go_ethereum/test_goethereum_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,23 +12,6 @@
from web3._utils.module_testing.go_ethereum_personal_module import (
GoEthereumAsyncPersonalModuleTest,
)
from web3.eth import (
AsyncEth,
)
from web3.geth import (
AsyncGethAdmin,
AsyncGethPersonal,
AsyncGethTxPool,
Geth,
)
from web3.middleware import (
async_buffered_gas_estimate_middleware,
async_gas_price_strategy_middleware,
async_validation_middleware,
)
from web3.net import (
AsyncNet,
)
from web3.providers.async_rpc import (
AsyncHTTPProvider,
)
Expand Down Expand Up @@ -137,26 +120,7 @@ class TestGoEthereumTxPoolModuleTest(GoEthereumTxPoolModuleTest):
@pytest_asyncio.fixture(scope="module")
async def async_w3(geth_process, endpoint_uri):
await wait_for_aiohttp(endpoint_uri)
_w3 = Web3(
AsyncHTTPProvider(endpoint_uri),
middlewares=[
async_buffered_gas_estimate_middleware,
async_gas_price_strategy_middleware,
async_validation_middleware,
],
modules={
"eth": AsyncEth,
"async_net": AsyncNet,
"geth": (
Geth,
{
"txpool": (AsyncGethTxPool,),
"personal": (AsyncGethPersonal,),
"admin": (AsyncGethAdmin,),
},
),
},
)
_w3 = Web3(AsyncHTTPProvider(endpoint_uri))
return _w3


Expand All @@ -165,22 +129,22 @@ class TestGoEthereumAsyncAdminModuleTest(GoEthereumAsyncAdminModuleTest):
@pytest.mark.xfail(
reason="running geth with the --nodiscover flag doesn't allow peer addition"
)
async def test_admin_peers(self, w3: "Web3") -> None:
await super().test_admin_peers(w3)
async def test_admin_peers(self, async_w3: "Web3") -> None:
await super().test_admin_peers(async_w3)

@pytest.mark.asyncio
async def test_admin_start_stop_http(self, w3: "Web3") -> None:
async def test_admin_start_stop_http(self, async_w3: "Web3") -> None:
# This test causes all tests after it to fail on CI if it's allowed to run
pytest.xfail(
reason="Only one HTTP endpoint is allowed to be active at any time"
)
await super().test_admin_start_stop_http(w3)
await super().test_admin_start_stop_http(async_w3)

@pytest.mark.asyncio
async def test_admin_start_stop_ws(self, w3: "Web3") -> None:
async def test_admin_start_stop_ws(self, async_w3: "Web3") -> None:
# This test causes all tests after it to fail on CI if it's allowed to run
pytest.xfail(reason="Only one WS endpoint is allowed to be active at any time")
await super().test_admin_start_stop_ws(w3)
await super().test_admin_start_stop_ws(async_w3)


class TestGoEthereumAsyncNetModuleTest(GoEthereumAsyncNetModuleTest):
Expand Down
4 changes: 3 additions & 1 deletion web3/_utils/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,11 +208,13 @@ def encode_abi(
)

normalizers = [
abi_ens_resolver(w3),
abi_address_to_hex,
abi_bytes_to_bytes,
abi_string_to_text,
]
if not w3.eth.is_async:
normalizers.append(abi_ens_resolver(w3))

normalized_arguments = map_abi_data(
normalizers,
argument_types,
Expand Down
6 changes: 3 additions & 3 deletions web3/_utils/module_testing/net_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ def test_net_peer_count(self, w3: "Web3") -> None:
class AsyncNetModuleTest:
@pytest.mark.asyncio
async def test_net_version(self, async_w3: "Web3") -> None:
version = await async_w3.async_net.version
version = await async_w3.net.version # type: ignore

assert is_string(version)
assert version.isdigit()

@pytest.mark.asyncio
async def test_net_listening(self, async_w3: "Web3") -> None:
listening = await async_w3.async_net.listening
listening = await async_w3.net.listening # type: ignore

assert is_boolean(listening)

@pytest.mark.asyncio
async def test_net_peer_count(self, async_w3: "Web3") -> None:
peer_count = await async_w3.async_net.peer_count
peer_count = await async_w3.net.peer_count # type: ignore

assert is_integer(peer_count)
9 changes: 7 additions & 2 deletions web3/_utils/normalizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,11 @@ def abi_address_to_hex(


@curry
def abi_ens_resolver(w3: "Web3", type_str: TypeStr, val: Any) -> Tuple[TypeStr, Any]:
def abi_ens_resolver(
w3: "Web3",
type_str: TypeStr,
val: Any,
) -> Tuple[TypeStr, Any]:
if type_str == "address" and is_ens_name(val):
if w3 is None:
raise InvalidAddress(
Expand All @@ -214,11 +218,12 @@ def abi_ens_resolver(w3: "Web3", type_str: TypeStr, val: Any) -> Tuple[TypeStr,
)

_ens = cast(ENS, w3.ens)
net_version = int(w3.net.version) if hasattr(w3, "net") else None
if _ens is None:
raise InvalidAddress(
f"Could not look up name {val!r} because ENS is" " set to None"
)
elif int(w3.net.version) != 1 and not isinstance(_ens, StaticENS):
elif net_version != 1 and not isinstance(_ens, StaticENS):
raise InvalidAddress(
f"Could not look up name {val!r} because web3 is"
" not connected to mainnet"
Expand Down
26 changes: 24 additions & 2 deletions web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,13 @@
abi_ens_resolver,
)
from web3.eth import (
AsyncEth,
Eth,
)
from web3.geth import (
AsyncGethAdmin,
AsyncGethPersonal,
AsyncGethTxPool,
Geth,
GethAdmin,
GethMiner,
Expand Down Expand Up @@ -126,6 +130,21 @@
from web3._utils.empty import Empty # noqa: F401


def get_async_default_modules() -> Dict[str, Union[Type[Module], Sequence[Any]]]:
return {
"eth": AsyncEth,
"net": AsyncNet,
"geth": (
Geth,
{
"admin": AsyncGethAdmin,
"personal": AsyncGethPersonal,
"txpool": AsyncGethTxPool,
},
),
}


def get_default_modules() -> Dict[str, Union[Type[Module], Sequence[Any]]]:
return {
"eth": Eth,
Expand Down Expand Up @@ -219,7 +238,6 @@ def to_checksum_address(value: Union[AnyAddress, str, bytes]) -> ChecksumAddress
eth: Eth
geth: Geth
net: Net
async_net: AsyncNet

def __init__(
self,
Expand All @@ -237,7 +255,11 @@ def __init__(
self.codec = ABICodec(build_default_registry())

if modules is None:
modules = get_default_modules()
modules = (
get_async_default_modules()
if provider and provider.is_async
else get_default_modules()
)

self.attach_modules(modules)

Expand Down
Loading