Skip to content

Added caching to the async session in request.py and AsyncHTTPProvide… #2363

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion docs/providers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,11 @@ AsyncHTTPProvider
be omitted from the URI.
* ``request_kwargs`` should be a dictionary of keyword arguments which
will be passed onto each http/https POST request made to your node.
* the ``cache_async_session()`` method allows you to use your own ``aiohttp.ClientSession`` object. This is an async method and not part of the constructor

.. code-block:: python

>>> from aiohttp import ClientSession
>>> from web3 import Web3, AsyncHTTPProvider
>>> from web3.eth import AsyncEth
>>> from web3.net import AsyncNet
Expand All @@ -396,7 +398,10 @@ AsyncHTTPProvider
... 'personal': (AsyncGethPersonal,),
... 'admin' : (AsyncGethAdmin,)})
... },
... middlewares=[]) # See supported middleware section below for middleware options
... middlewares=[] # See supported middleware section below for middleware options
... )
>>> custom_session = ClientSession() # If you want to pass in your own session
>>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly

Under the hood, the ``AsyncHTTPProvider`` uses the python
`aiohttp <https://docs.aiohttp.org/en/stable/>`_ library for making requests.
Expand Down
1 change: 1 addition & 0 deletions newsfragments/2016.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Added session caching to the AsyncHTTPProvider
22 changes: 22 additions & 0 deletions tests/core/providers/test_async_http_provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@

import pytest

from aiohttp import (
ClientSession,
)

from web3._utils import (
request,
)
from web3.providers.async_rpc import (
AsyncHTTPProvider,
)


@pytest.mark.asyncio
async def test_user_provided_session() -> None:

session = ClientSession()
provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545")
await provider.cache_async_session(session)
assert len(request._async_session_cache) == 1
52 changes: 52 additions & 0 deletions tests/core/utilities/test_request.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
import pytest

from aiohttp import (
ClientSession,
)
from requests import (
Session,
adapters,
Expand All @@ -10,6 +15,9 @@
from web3._utils import (
request,
)
from web3._utils.request import (
SessionCache,
)


class MockedResponse:
Expand Down Expand Up @@ -80,3 +88,47 @@ def test_precached_session(mocker):
assert isinstance(adapter, HTTPAdapter)
assert adapter._pool_connections == 100
assert adapter._pool_maxsize == 100


@pytest.mark.asyncio
async def test_async_precached_session(mocker):
# Add a session
session = ClientSession()
await request.cache_async_session(URI, session)
assert len(request._async_session_cache) == 1

# Make sure the session isn't duplicated
await request.cache_async_session(URI, session)
assert len(request._async_session_cache) == 1

# Make sure a request with a different URI adds another cached session
await request.cache_async_session(f"{URI}/test", session)
assert len(request._async_session_cache) == 2


def test_cache_session_class():

cache = SessionCache(2)
evicted_items = cache.cache("1", "Hello1")
assert cache.get_cache_entry("1") == "Hello1"
assert evicted_items is None

evicted_items = cache.cache("2", "Hello2")
assert cache.get_cache_entry("2") == "Hello2"
assert evicted_items is None

# Changing what is stored at a given cache key should not cause the
# anything to be evicted
evicted_items = cache.cache("1", "HelloChanged")
assert cache.get_cache_entry("1") == "HelloChanged"
assert evicted_items is None

evicted_items = cache.cache("3", "Hello3")
assert "2" in cache
assert "3" in cache
assert "1" not in cache

with pytest.raises(KeyError):
# This should throw a KeyError since the cache size was 2 and 3 were inserted
# the first inserted cached item was removed and returned in evicted items
cache.get_cache_entry("1")
67 changes: 61 additions & 6 deletions web3/_utils/request.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from collections import (
OrderedDict,
)
import os
import threading
from typing import (
Any,
Dict,
)

from aiohttp import (
Expand All @@ -18,6 +23,37 @@
)


class SessionCache:

def __init__(self, size: int):
self._size = size
self._data: OrderedDict[str, Any] = OrderedDict()

def cache(self, key: str, value: Any) -> Dict[str, Any]:
evicted_items = None
# If the key is already in the OrderedDict just update it
# and don't evict any values. Ideally, we could still check to see
# if there are too many items in the OrderedDict but that may rearrange
# the order it should be unlikely that the size could grow over the limit
if key not in self._data:
while len(self._data) >= self._size:
if evicted_items is None:
evicted_items = {}
k, v = self._data.popitem(last=False)
evicted_items[k] = v
self._data[key] = value
return evicted_items

def get_cache_entry(self, key: str) -> Any:
return self._data[key]

def __contains__(self, item: str) -> bool:
return item in self._data

def __len__(self) -> int:
return len(self._data)


def get_default_http_endpoint() -> URI:
return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545'))

Expand All @@ -27,11 +63,22 @@ def cache_session(endpoint_uri: URI, session: requests.Session) -> None:
_session_cache[cache_key] = session


async def cache_async_session(endpoint_uri: URI, session: ClientSession) -> None:
cache_key = generate_cache_key(endpoint_uri)
with _async_session_cache_lock:
evicted_items = _async_session_cache.cache(cache_key, session)
if evicted_items is not None:
for key, session in evicted_items.items():
await session.close()


def _remove_session(key: str, session: requests.Session) -> None:
session.close()


_session_cache = lru.LRU(8, callback=_remove_session)
_async_session_cache_lock = threading.Lock()
_async_session_cache = SessionCache(size=8)


def _get_session(endpoint_uri: URI) -> requests.Session:
Expand All @@ -41,6 +88,13 @@ def _get_session(endpoint_uri: URI) -> requests.Session:
return _session_cache[cache_key]


async def _get_async_session(endpoint_uri: URI) -> ClientSession:
cache_key = generate_cache_key(endpoint_uri)
if cache_key not in _async_session_cache:
await cache_async_session(endpoint_uri, ClientSession(raise_for_status=True))
return _async_session_cache.get_cache_entry(cache_key)


def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any) -> bytes:
kwargs.setdefault('timeout', 10)
session = _get_session(endpoint_uri)
Expand All @@ -55,9 +109,10 @@ async def async_make_post_request(
endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any
) -> bytes:
kwargs.setdefault('timeout', ClientTimeout(10))
async with ClientSession(raise_for_status=True) as session:
async with session.post(endpoint_uri,
data=data,
*args,
**kwargs) as response:
return await response.read()
# https://github.com/ethereum/go-ethereum/issues/17069
session = await _get_async_session(endpoint_uri)
async with session.post(endpoint_uri,
data=data,
*args,
**kwargs) as response:
return await response.read()
7 changes: 7 additions & 0 deletions web3/providers/async_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
Union,
)

from aiohttp import (
ClientSession,
)
from eth_typing import (
URI,
)
Expand All @@ -20,6 +23,7 @@
)
from web3._utils.request import (
async_make_post_request,
cache_async_session as _cache_async_session,
get_default_http_endpoint,
)
from web3.types import (
Expand Down Expand Up @@ -50,6 +54,9 @@ def __init__(

super().__init__()

async def cache_async_session(self, session: ClientSession) -> None:
await _cache_async_session(self.endpoint_uri, session)

def __str__(self) -> str:
return "RPC connection {0}".format(self.endpoint_uri)

Expand Down