Skip to content

Commit eed1ae4

Browse files
authored
Added caching to the async session in request.py and AsyncHTTPProvider (#2254)
* Added caching to the async session in request.py and AsyncHTTPProvider
1 parent 313d919 commit eed1ae4

File tree

6 files changed

+149
-7
lines changed

6 files changed

+149
-7
lines changed

docs/providers.rst

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,11 @@ AsyncHTTPProvider
379379
be omitted from the URI.
380380
* ``request_kwargs`` should be a dictionary of keyword arguments which
381381
will be passed onto each http/https POST request made to your node.
382+
* the ``cache_async_session()`` method allows you to use your own ``aiohttp.ClientSession`` object. This is an async method and not part of the constructor
382383

383384
.. code-block:: python
384385
386+
>>> from aiohttp import ClientSession
385387
>>> from web3 import Web3, AsyncHTTPProvider
386388
>>> from web3.eth import AsyncEth
387389
>>> from web3.net import AsyncNet
@@ -396,7 +398,10 @@ AsyncHTTPProvider
396398
... 'personal': (AsyncGethPersonal,),
397399
... 'admin' : (AsyncGethAdmin,)})
398400
... },
399-
... middlewares=[]) # See supported middleware section below for middleware options
401+
... middlewares=[] # See supported middleware section below for middleware options
402+
... )
403+
>>> custom_session = ClientSession() # If you want to pass in your own session
404+
>>> await w3.provider.cache_async_session(custom_session) # This method is an async method so it needs to be handled accordingly
400405
401406
Under the hood, the ``AsyncHTTPProvider`` uses the python
402407
`aiohttp <https://docs.aiohttp.org/en/stable/>`_ library for making requests.

newsfragments/2016.feature.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Added session caching to the AsyncHTTPProvider
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
2+
import pytest
3+
4+
from aiohttp import (
5+
ClientSession,
6+
)
7+
8+
from web3._utils import (
9+
request,
10+
)
11+
from web3.providers.async_rpc import (
12+
AsyncHTTPProvider,
13+
)
14+
15+
16+
@pytest.mark.asyncio
17+
async def test_user_provided_session() -> None:
18+
19+
session = ClientSession()
20+
provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545")
21+
await provider.cache_async_session(session)
22+
assert len(request._async_session_cache) == 1

tests/core/utilities/test_request.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
import pytest
2+
3+
from aiohttp import (
4+
ClientSession,
5+
)
16
from requests import (
27
Session,
38
adapters,
@@ -10,6 +15,9 @@
1015
from web3._utils import (
1116
request,
1217
)
18+
from web3._utils.request import (
19+
SessionCache,
20+
)
1321

1422

1523
class MockedResponse:
@@ -80,3 +88,47 @@ def test_precached_session(mocker):
8088
assert isinstance(adapter, HTTPAdapter)
8189
assert adapter._pool_connections == 100
8290
assert adapter._pool_maxsize == 100
91+
92+
93+
@pytest.mark.asyncio
94+
async def test_async_precached_session(mocker):
95+
# Add a session
96+
session = ClientSession()
97+
await request.cache_async_session(URI, session)
98+
assert len(request._async_session_cache) == 1
99+
100+
# Make sure the session isn't duplicated
101+
await request.cache_async_session(URI, session)
102+
assert len(request._async_session_cache) == 1
103+
104+
# Make sure a request with a different URI adds another cached session
105+
await request.cache_async_session(f"{URI}/test", session)
106+
assert len(request._async_session_cache) == 2
107+
108+
109+
def test_cache_session_class():
110+
111+
cache = SessionCache(2)
112+
evicted_items = cache.cache("1", "Hello1")
113+
assert cache.get_cache_entry("1") == "Hello1"
114+
assert evicted_items is None
115+
116+
evicted_items = cache.cache("2", "Hello2")
117+
assert cache.get_cache_entry("2") == "Hello2"
118+
assert evicted_items is None
119+
120+
# Changing what is stored at a given cache key should not cause the
121+
# anything to be evicted
122+
evicted_items = cache.cache("1", "HelloChanged")
123+
assert cache.get_cache_entry("1") == "HelloChanged"
124+
assert evicted_items is None
125+
126+
evicted_items = cache.cache("3", "Hello3")
127+
assert "2" in cache
128+
assert "3" in cache
129+
assert "1" not in cache
130+
131+
with pytest.raises(KeyError):
132+
# This should throw a KeyError since the cache size was 2 and 3 were inserted
133+
# the first inserted cached item was removed and returned in evicted items
134+
cache.get_cache_entry("1")

web3/_utils/request.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
from collections import (
2+
OrderedDict,
3+
)
14
import os
5+
import threading
26
from typing import (
37
Any,
8+
Dict,
49
)
510

611
from aiohttp import (
@@ -18,6 +23,37 @@
1823
)
1924

2025

26+
class SessionCache:
27+
28+
def __init__(self, size: int):
29+
self._size = size
30+
self._data: OrderedDict[str, Any] = OrderedDict()
31+
32+
def cache(self, key: str, value: Any) -> Dict[str, Any]:
33+
evicted_items = None
34+
# If the key is already in the OrderedDict just update it
35+
# and don't evict any values. Ideally, we could still check to see
36+
# if there are too many items in the OrderedDict but that may rearrange
37+
# the order it should be unlikely that the size could grow over the limit
38+
if key not in self._data:
39+
while len(self._data) >= self._size:
40+
if evicted_items is None:
41+
evicted_items = {}
42+
k, v = self._data.popitem(last=False)
43+
evicted_items[k] = v
44+
self._data[key] = value
45+
return evicted_items
46+
47+
def get_cache_entry(self, key: str) -> Any:
48+
return self._data[key]
49+
50+
def __contains__(self, item: str) -> bool:
51+
return item in self._data
52+
53+
def __len__(self) -> int:
54+
return len(self._data)
55+
56+
2157
def get_default_http_endpoint() -> URI:
2258
return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545'))
2359

@@ -27,11 +63,22 @@ def cache_session(endpoint_uri: URI, session: requests.Session) -> None:
2763
_session_cache[cache_key] = session
2864

2965

66+
async def cache_async_session(endpoint_uri: URI, session: ClientSession) -> None:
67+
cache_key = generate_cache_key(endpoint_uri)
68+
with _async_session_cache_lock:
69+
evicted_items = _async_session_cache.cache(cache_key, session)
70+
if evicted_items is not None:
71+
for key, session in evicted_items.items():
72+
await session.close()
73+
74+
3075
def _remove_session(key: str, session: requests.Session) -> None:
3176
session.close()
3277

3378

3479
_session_cache = lru.LRU(8, callback=_remove_session)
80+
_async_session_cache_lock = threading.Lock()
81+
_async_session_cache = SessionCache(size=8)
3582

3683

3784
def _get_session(endpoint_uri: URI) -> requests.Session:
@@ -41,6 +88,13 @@ def _get_session(endpoint_uri: URI) -> requests.Session:
4188
return _session_cache[cache_key]
4289

4390

91+
async def _get_async_session(endpoint_uri: URI) -> ClientSession:
92+
cache_key = generate_cache_key(endpoint_uri)
93+
if cache_key not in _async_session_cache:
94+
await cache_async_session(endpoint_uri, ClientSession(raise_for_status=True))
95+
return _async_session_cache.get_cache_entry(cache_key)
96+
97+
4498
def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any) -> bytes:
4599
kwargs.setdefault('timeout', 10)
46100
session = _get_session(endpoint_uri)
@@ -55,9 +109,10 @@ async def async_make_post_request(
55109
endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any
56110
) -> bytes:
57111
kwargs.setdefault('timeout', ClientTimeout(10))
58-
async with ClientSession(raise_for_status=True) as session:
59-
async with session.post(endpoint_uri,
60-
data=data,
61-
*args,
62-
**kwargs) as response:
63-
return await response.read()
112+
# https://github.com/ethereum/go-ethereum/issues/17069
113+
session = await _get_async_session(endpoint_uri)
114+
async with session.post(endpoint_uri,
115+
data=data,
116+
*args,
117+
**kwargs) as response:
118+
return await response.read()

web3/providers/async_rpc.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
Union,
99
)
1010

11+
from aiohttp import (
12+
ClientSession,
13+
)
1114
from eth_typing import (
1215
URI,
1316
)
@@ -20,6 +23,7 @@
2023
)
2124
from web3._utils.request import (
2225
async_make_post_request,
26+
cache_async_session as _cache_async_session,
2327
get_default_http_endpoint,
2428
)
2529
from web3.types import (
@@ -50,6 +54,9 @@ def __init__(
5054

5155
super().__init__()
5256

57+
async def cache_async_session(self, session: ClientSession) -> None:
58+
await _cache_async_session(self.endpoint_uri, session)
59+
5360
def __str__(self) -> str:
5461
return "RPC connection {0}".format(self.endpoint_uri)
5562

0 commit comments

Comments
 (0)