From 34d782edfbdbd0948b15f3fbe1144e2c27314156 Mon Sep 17 00:00:00 2001 From: AlwaysData Date: Thu, 24 Feb 2022 18:10:40 -0500 Subject: [PATCH] Added caching to the async session in request.py and AsyncHTTPProvider (#2254) * Added caching to the async session in request.py and AsyncHTTPProvider --- docs/providers.rst | 7 +- newsfragments/2016.feature.rst | 1 + .../providers/test_async_http_provider.py | 22 ++++++ tests/core/utilities/test_request.py | 52 ++++++++++++++ web3/_utils/request.py | 67 +++++++++++++++++-- web3/providers/async_rpc.py | 7 ++ 6 files changed, 149 insertions(+), 7 deletions(-) create mode 100644 newsfragments/2016.feature.rst create mode 100644 tests/core/providers/test_async_http_provider.py diff --git a/docs/providers.rst b/docs/providers.rst index 90d6439857..cb1395ecd7 100644 --- a/docs/providers.rst +++ b/docs/providers.rst @@ -379,9 +379,11 @@ AsyncHTTPProvider be omitted from the URI. * ``request_kwargs`` should be a dictionary of keyword arguments which will be passed onto each http/https POST request made to your node. + * the ``cache_async_session()`` method allows you to use your own ``aiohttp.ClientSession`` object. This is an async method and not part of the constructor .. code-block:: python + >>> from aiohttp import ClientSession >>> from web3 import Web3, AsyncHTTPProvider >>> from web3.eth import AsyncEth >>> from web3.net import AsyncNet @@ -396,7 +398,10 @@ AsyncHTTPProvider ... 'personal': (AsyncGethPersonal,), ... 'admin' : (AsyncGethAdmin,)}) ... }, - ... middlewares=[]) # See supported middleware section below for middleware options + ... middlewares=[] # See supported middleware section below for middleware options + ... 
) + >>> custom_session = ClientSession() # If you want to pass in your own session + >>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly Under the hood, the ``AsyncHTTPProvider`` uses the python `aiohttp `_ library for making requests. diff --git a/newsfragments/2016.feature.rst b/newsfragments/2016.feature.rst new file mode 100644 index 0000000000..fb34134740 --- /dev/null +++ b/newsfragments/2016.feature.rst @@ -0,0 +1 @@ +Added session caching to the AsyncHTTPProvider \ No newline at end of file diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py new file mode 100644 index 0000000000..bceea41e54 --- /dev/null +++ b/tests/core/providers/test_async_http_provider.py @@ -0,0 +1,22 @@ + +import pytest + +from aiohttp import ( + ClientSession, +) + +from web3._utils import ( + request, +) +from web3.providers.async_rpc import ( + AsyncHTTPProvider, +) + + +@pytest.mark.asyncio +async def test_user_provided_session() -> None: + + session = ClientSession() + provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545") + await provider.cache_async_session(session) + assert len(request._async_session_cache) == 1 diff --git a/tests/core/utilities/test_request.py b/tests/core/utilities/test_request.py index 6947da4eb9..32d4cf9baa 100644 --- a/tests/core/utilities/test_request.py +++ b/tests/core/utilities/test_request.py @@ -1,3 +1,8 @@ +import pytest + +from aiohttp import ( + ClientSession, +) from requests import ( Session, adapters, @@ -10,6 +15,9 @@ from web3._utils import ( request, ) +from web3._utils.request import ( + SessionCache, +) class MockedResponse: @@ -80,3 +88,47 @@ def test_precached_session(mocker): assert isinstance(adapter, HTTPAdapter) assert adapter._pool_connections == 100 assert adapter._pool_maxsize == 100 + + +@pytest.mark.asyncio +async def test_async_precached_session(mocker): + # Add a session + 
session = ClientSession() + await request.cache_async_session(URI, session) + assert len(request._async_session_cache) == 1 + + # Make sure the session isn't duplicated + await request.cache_async_session(URI, session) + assert len(request._async_session_cache) == 1 + + # Make sure a request with a different URI adds another cached session + await request.cache_async_session(f"{URI}/test", session) + assert len(request._async_session_cache) == 2 + + +def test_cache_session_class(): + + cache = SessionCache(2) + evicted_items = cache.cache("1", "Hello1") + assert cache.get_cache_entry("1") == "Hello1" + assert evicted_items is None + + evicted_items = cache.cache("2", "Hello2") + assert cache.get_cache_entry("2") == "Hello2" + assert evicted_items is None + + # Changing what is stored at a given cache key should not cause + # anything to be evicted + evicted_items = cache.cache("1", "HelloChanged") + assert cache.get_cache_entry("1") == "HelloChanged" + assert evicted_items is None + + evicted_items = cache.cache("3", "Hello3") + assert "2" in cache + assert "3" in cache + assert "1" not in cache + + with pytest.raises(KeyError): + # This should throw a KeyError since the cache size was 2 and 3 were inserted + # the first inserted cached item was removed and returned in evicted items + cache.get_cache_entry("1") diff --git a/web3/_utils/request.py b/web3/_utils/request.py index 0ddfc81354..fadafe5ed9 100644 --- a/web3/_utils/request.py +++ b/web3/_utils/request.py @@ -1,6 +1,11 @@ +from collections import ( + OrderedDict, +) import os +import threading from typing import ( Any, + Dict, ) from aiohttp import ( @@ -18,6 +23,37 @@ ) +class SessionCache: + + def __init__(self, size: int): + self._size = size + self._data: OrderedDict[str, Any] = OrderedDict() + + def cache(self, key: str, value: Any) -> Dict[str, Any]: + evicted_items = None + # If the key is already in the OrderedDict just update it + # and don't evict any values. 
Ideally, we could still check to see + # if there are too many items in the OrderedDict but that may rearrange + # the order; it should be unlikely that the size could grow over the limit + if key not in self._data: + while len(self._data) >= self._size: + if evicted_items is None: + evicted_items = {} + k, v = self._data.popitem(last=False) + evicted_items[k] = v + self._data[key] = value + return evicted_items + + def get_cache_entry(self, key: str) -> Any: + return self._data[key] + + def __contains__(self, item: str) -> bool: + return item in self._data + + def __len__(self) -> int: + return len(self._data) + + def get_default_http_endpoint() -> URI: return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545')) @@ -27,11 +63,22 @@ def cache_session(endpoint_uri: URI, session: requests.Session) -> None: _session_cache[cache_key] = session +async def cache_async_session(endpoint_uri: URI, session: ClientSession) -> None: + cache_key = generate_cache_key(endpoint_uri) + with _async_session_cache_lock: + evicted_items = _async_session_cache.cache(cache_key, session) + if evicted_items is not None: + for key, session in evicted_items.items(): + await session.close() + + def _remove_session(key: str, session: requests.Session) -> None: session.close() _session_cache = lru.LRU(8, callback=_remove_session) +_async_session_cache_lock = threading.Lock() +_async_session_cache = SessionCache(size=8) def _get_session(endpoint_uri: URI) -> requests.Session: @@ -41,6 +88,13 @@ def _get_session(endpoint_uri: URI) -> requests.Session: return _session_cache[cache_key] +async def _get_async_session(endpoint_uri: URI) -> ClientSession: + cache_key = generate_cache_key(endpoint_uri) + if cache_key not in _async_session_cache: + await cache_async_session(endpoint_uri, ClientSession(raise_for_status=True)) + return _async_session_cache.get_cache_entry(cache_key) + + def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any) -> bytes: 
kwargs.setdefault('timeout', 10) session = _get_session(endpoint_uri) @@ -55,9 +109,10 @@ async def async_make_post_request( endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any ) -> bytes: kwargs.setdefault('timeout', ClientTimeout(10)) - async with ClientSession(raise_for_status=True) as session: - async with session.post(endpoint_uri, - data=data, - *args, - **kwargs) as response: - return await response.read() + # https://github.com/ethereum/go-ethereum/issues/17069 + session = await _get_async_session(endpoint_uri) + async with session.post(endpoint_uri, + data=data, + *args, + **kwargs) as response: + return await response.read() diff --git a/web3/providers/async_rpc.py b/web3/providers/async_rpc.py index 82f945a4d9..d3229c71de 100644 --- a/web3/providers/async_rpc.py +++ b/web3/providers/async_rpc.py @@ -8,6 +8,9 @@ Union, ) +from aiohttp import ( + ClientSession, +) from eth_typing import ( URI, ) @@ -20,6 +23,7 @@ ) from web3._utils.request import ( async_make_post_request, + cache_async_session as _cache_async_session, get_default_http_endpoint, ) from web3.types import ( @@ -50,6 +54,9 @@ def __init__( super().__init__() + async def cache_async_session(self, session: ClientSession) -> None: + await _cache_async_session(self.endpoint_uri, session) + def __str__(self) -> str: return "RPC connection {0}".format(self.endpoint_uri)