diff --git a/newsfragments/2409.bugfix.rst b/newsfragments/2409.bugfix.rst new file mode 100644 index 0000000000..fe997cd6aa --- /dev/null +++ b/newsfragments/2409.bugfix.rst @@ -0,0 +1 @@ +Improve upon issues with session caching - better support for multithreading and make sure session eviction from cache does not happen prematurely. \ No newline at end of file diff --git a/tests/core/providers/test_async_http_provider.py b/tests/core/providers/test_async_http_provider.py index 111e24bb45..c5c3b6cc95 100644 --- a/tests/core/providers/test_async_http_provider.py +++ b/tests/core/providers/test_async_http_provider.py @@ -14,8 +14,8 @@ @pytest.mark.asyncio async def test_user_provided_session() -> None: - session = ClientSession() provider = AsyncHTTPProvider(endpoint_uri="http://mynode.local:8545") - await provider.cache_async_session(session) + cached_session = await provider.cache_async_session(session) assert len(request._async_session_cache) == 1 + assert cached_session == session diff --git a/tests/core/providers/test_http_provider.py b/tests/core/providers/test_http_provider.py index 247053d0d0..0e6e15b502 100644 --- a/tests/core/providers/test_http_provider.py +++ b/tests/core/providers/test_http_provider.py @@ -38,7 +38,7 @@ def test_user_provided_session(): w3 = Web3(provider) assert w3.manager.provider == provider - session = request.get_session(URI) + session = request.cache_and_return_session(URI) adapter = session.get_adapter(URI) assert isinstance(adapter, HTTPAdapter) assert adapter._pool_connections == 20 diff --git a/tests/core/utilities/test_request.py b/tests/core/utilities/test_request.py index a1f6011227..56c405c89b 100644 --- a/tests/core/utilities/test_request.py +++ b/tests/core/utilities/test_request.py @@ -1,8 +1,16 @@ +import asyncio +from concurrent.futures import ( + ThreadPoolExecutor, +) import pytest +import time from aiohttp import ( ClientSession, ) +from eth_typing import ( + URI, +) from requests import ( Session, adapters, @@ -15,8 +23,13 @@ from web3._utils import ( request, ) +from web3._utils.caching import ( + generate_cache_key, +) from web3._utils.request import ( SessionCache, + cache_and_return_async_session, + cache_and_return_session, ) @@ -34,7 +47,14 @@ def raise_for_status(self): pass -URI = "http://mynode.local:8545" +TEST_URI = URI("http://mynode.local:8545") +UNIQUE_URIS = [ + "https://www.test1.com", + "https://www.test2.com", + "https://www.test3.com", + "https://www.test4.com", + "https://www.test5.com", +] def check_adapters_mounted(session: Session): @@ -48,19 +68,23 @@ def test_make_post_request_no_args(mocker): # Submit a first request to create a session with default parameters assert len(request._session_cache) == 0 - response = request.make_post_request(URI, data=b"request") + response = request.make_post_request(TEST_URI, data=b"request") assert response == "content" assert len(request._session_cache) == 1 - session = request._session_cache.values()[0] - session.post.assert_called_once_with(URI, data=b"request", timeout=10) + cache_key = generate_cache_key(TEST_URI) + session = request._session_cache.get_cache_entry(cache_key) + session.post.assert_called_once_with(TEST_URI, data=b"request", timeout=10) # Ensure the adapter was created with default values check_adapters_mounted(session) - adapter = session.get_adapter(URI) + adapter = session.get_adapter(TEST_URI) assert isinstance(adapter, HTTPAdapter) assert adapter._pool_connections == DEFAULT_POOLSIZE assert adapter._pool_maxsize == DEFAULT_POOLSIZE + # clear cache + request._session_cache.clear() + def test_precached_session(mocker): mocker.patch("requests.Session.post", return_value=MockedResponse()) @@ -70,44 +94,30 @@ def test_precached_session(mocker): session = Session() session.mount("http://", adapter) session.mount("https://", adapter) - request.cache_session(URI, session) + request.cache_and_return_session(TEST_URI, session) # Submit a second request with different arguments assert len(request._session_cache) == 1 - response = request.make_post_request(URI, data=b"request", timeout=60) + response = request.make_post_request(TEST_URI, data=b"request", timeout=60) assert response == "content" assert len(request._session_cache) == 1 # Ensure the timeout was passed to the request - session = request.get_session(URI) - session.post.assert_called_once_with(URI, data=b"request", timeout=60) + session = request.cache_and_return_session(TEST_URI) + session.post.assert_called_once_with(TEST_URI, data=b"request", timeout=60) # Ensure the adapter parameters match those we specified check_adapters_mounted(session) - adapter = session.get_adapter(URI) + adapter = session.get_adapter(TEST_URI) assert isinstance(adapter, HTTPAdapter) assert adapter._pool_connections == 100 assert adapter._pool_maxsize == 100 - -@pytest.mark.asyncio -async def test_async_precached_session(mocker): - # Add a session - session = ClientSession() - await request.cache_async_session(URI, session) - assert len(request._async_session_cache) == 1 - - # Make sure the session isn't duplicated - await request.cache_async_session(URI, session) - assert len(request._async_session_cache) == 1 - - # Make sure a request with a different URI adds another cached session - await request.cache_async_session(f"{URI}/test", session) - assert len(request._async_session_cache) == 2 + # clear cache + request._session_cache.clear() def test_cache_session_class(): - cache = SessionCache(2) evicted_items = cache.cache("1", "Hello1") assert cache.get_cache_entry("1") == "Hello1" @@ -126,9 +136,129 @@ def test_cache_session_class(): evicted_items = cache.cache("3", "Hello3") assert "2" in cache assert "3" in cache + assert "1" not in cache + assert "1" in evicted_items with pytest.raises(KeyError): # This should throw a KeyError since the cache size was 2 and 3 were inserted # the first inserted cached item was removed and returned in evicted items cache.get_cache_entry("1") + + # clear cache + request._session_cache.clear() + + +def test_cache_does_not_close_session_before_a_call_when_multithreading(): + # save default values + session_cache_default = request._async_session_cache + timeout_default = request.DEFAULT_TIMEOUT + + # set cache size to 1 + set future session close thread time to 0.01s + request._session_cache = SessionCache(1) + _timeout_for_testing = 0.01 + request.DEFAULT_TIMEOUT = _timeout_for_testing + + def _simulate_call(uri): + _session = cache_and_return_session(uri) + + # simulate a call taking 0.01s to return a response + time.sleep(0.01) + return _session + + with ThreadPoolExecutor(max_workers=len(UNIQUE_URIS)) as exc: + all_sessions = [exc.submit(_simulate_call, uri) for uri in UNIQUE_URIS] + + # assert last session remains in cache, all others evicted + cache_data = request._session_cache._data + assert len(cache_data) == 1 + _key, cached_session = cache_data.popitem() + assert cached_session == all_sessions[-1].result() # result of the `Future` + + # -- teardown -- # + + # close the cached session before exiting test + cached_session.close() + + # reset default values + request._async_session_cache = session_cache_default + request.DEFAULT_TIMEOUT = timeout_default + + # clear cache + request._session_cache.clear() + + +# -- async -- # + + +@pytest.mark.asyncio +async def test_async_precached_session(): + # Add a session + session = ClientSession() + await request.cache_and_return_async_session(TEST_URI, session) + assert len(request._async_session_cache) == 1 + + # Make sure the session isn't duplicated + await request.cache_and_return_async_session(TEST_URI, session) + assert len(request._async_session_cache) == 1 + + # Make sure a request with a different URI adds another cached session + await request.cache_and_return_async_session(URI(f"{TEST_URI}/test"), session) + assert len(request._async_session_cache) == 2 + + # -- teardown -- # + + # appropriately close the cached sessions + [await session.close() for session in request._async_session_cache._data.values()] + + # clear cache + request._async_session_cache.clear() + + +@pytest.mark.asyncio +async def test_async_cache_does_not_close_session_before_a_call_when_multithreading(): + # save default values + session_cache_default = request._async_session_cache + timeout_default = request.DEFAULT_TIMEOUT + + # set cache size to 1 + set future session close thread time to 0.01s + request._async_session_cache = SessionCache(1) + _timeout_for_testing = 0.01 + request.DEFAULT_TIMEOUT = _timeout_for_testing + + async def cache_uri_and_return_session(uri): + _session = await cache_and_return_async_session(uri) + + # simulate a call taking 0.01s to return a response + await asyncio.sleep(0.01) + + assert not _session.closed + return _session + + tasks = [cache_uri_and_return_session(uri) for uri in UNIQUE_URIS] + + all_sessions = await asyncio.gather(*tasks) + assert len(all_sessions) == len(UNIQUE_URIS) + assert all(isinstance(s, ClientSession) for s in all_sessions) + + # last session remains in cache, all others evicted + cache_data = request._async_session_cache._data + assert len(cache_data) == 1 + _key, cached_session = cache_data.popitem() + assert cached_session == all_sessions[-1] + + # assert all evicted sessions were closed + await asyncio.sleep(_timeout_for_testing + 0.1) + assert all(session.closed for session in all_sessions[:-1]) + + # -- teardown -- # + + # appropriately close the cached session + await cached_session.close() + + # reset default values + request._async_session_cache = session_cache_default + request.DEFAULT_TIMEOUT = timeout_default + + # clear cache + request._async_session_cache.clear() diff --git a/web3/_utils/module_testing/module_testing_utils.py b/web3/_utils/module_testing/module_testing_utils.py index cbf86cbebb..9b5a6bab6e 100644 --- a/web3/_utils/module_testing/module_testing_utils.py +++ b/web3/_utils/module_testing/module_testing_utils.py @@ -26,8 +26,8 @@ Literal, ) from web3._utils.request import ( - get_async_session, - get_session, + cache_and_return_async_session, + cache_and_return_session, ) from web3.types import ( BlockData, @@ -103,7 +103,7 @@ def _mock_specific_request( return MockedResponse() # else, make a normal request (no mocking) - session = get_session(url_from_args) + session = cache_and_return_session(url_from_args) return session.request(method=http_method.upper(), url=url_from_args, **kwargs) monkeypatch.setattr( @@ -153,7 +153,7 @@ async def _mock_specific_request( return AsyncMockedResponse() # else, make a normal request (no mocking) - session = await get_async_session(url_from_args) + session = await cache_and_return_async_session(url_from_args) return await session.request( method=http_method.upper(), url=url_from_args, **kwargs ) diff --git a/web3/_utils/request.py b/web3/_utils/request.py index 242e20cb4b..5a3726af08 100644 --- a/web3/_utils/request.py +++ b/web3/_utils/request.py @@ -1,11 +1,20 @@ +import asyncio from collections import ( OrderedDict, ) +from concurrent.futures import ( + ThreadPoolExecutor, +) +import contextlib +import logging import os import threading from typing import ( Any, + AsyncGenerator, Dict, + List, + Optional, Union, ) @@ -17,13 +26,16 @@ from eth_typing import ( URI, ) -import lru import requests from web3._utils.caching import ( generate_cache_key, ) +logger = logging.getLogger(__name__) + +DEFAULT_TIMEOUT = 10 + class SessionCache: def __init__(self, size: int): @@ -48,6 +60,9 @@ def cache(self, key: str, value: Any) -> Dict[str, Any]: def get_cache_entry(self, key: str) -> Any: return self._data[key] + def clear(self) -> None: + self._data.clear() + def __contains__(self, item: str) -> bool: return item in self._data @@ -59,30 +74,46 @@ def get_default_http_endpoint() -> URI: return URI(os.environ.get("WEB3_HTTP_PROVIDER_URI", "http://localhost:8545")) -def cache_session(endpoint_uri: URI, session: requests.Session) -> None: - cache_key = generate_cache_key(endpoint_uri) - _session_cache[cache_key] = session +_session_cache = SessionCache(size=20) +_session_cache_lock = threading.Lock() -def _remove_session(_key: str, session: requests.Session) -> None: - session.close() +def cache_and_return_session( + endpoint_uri: URI, session: requests.Session = None +) -> requests.Session: + cache_key = generate_cache_key(endpoint_uri) + evicted_items = None + with _session_cache_lock: + if cache_key not in _session_cache: + if session is None: + session = requests.Session() -_session_cache = lru.LRU(8, callback=_remove_session) + evicted_items = _session_cache.cache(cache_key, session) + logger.debug(f"Session cached: {endpoint_uri}, {session}") + cached_session = _session_cache.get_cache_entry(cache_key) -def get_session(endpoint_uri: URI) -> requests.Session: - cache_key = generate_cache_key(endpoint_uri) - if cache_key not in _session_cache: - _session_cache[cache_key] = requests.Session() - return _session_cache[cache_key] + if evicted_items is not None: + evicted_sessions = evicted_items.values() + for evicted_session in evicted_sessions: + logger.debug( + f"Session cache full. Session evicted from cache: {evicted_session}", + ) + threading.Timer( + DEFAULT_TIMEOUT + 0.1, + _close_evicted_sessions, + args=[evicted_sessions], + ).start() + + return cached_session def get_response_from_get_request( endpoint_uri: URI, *args: Any, **kwargs: Any ) -> requests.Response: - kwargs.setdefault("timeout", 10) - session = get_session(endpoint_uri) + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + session = cache_and_return_session(endpoint_uri) response = session.get(endpoint_uri, *args, **kwargs) return response @@ -90,8 +121,8 @@ def get_response_from_get_request( def get_response_from_post_request( endpoint_uri: URI, *args: Any, **kwargs: Any ) -> requests.Response: - kwargs.setdefault("timeout", 10) - session = get_session(endpoint_uri) + kwargs.setdefault("timeout", DEFAULT_TIMEOUT) + session = cache_and_return_session(endpoint_uri) response = session.post(endpoint_uri, *args, **kwargs) return response @@ -104,33 +135,77 @@ def make_post_request( return response.content +def _close_evicted_sessions(evicted_sessions: List[requests.Session]) -> None: + for evicted_session in evicted_sessions: + evicted_session.close() + logger.debug(f"Closed evicted session: {evicted_session}") + + # --- async --- # -_async_session_cache_lock = threading.Lock() + _async_session_cache = SessionCache(size=20) +_async_session_cache_lock = threading.Lock() +_pool = ThreadPoolExecutor(max_workers=1) -async def cache_async_session(endpoint_uri: URI, session: ClientSession) -> None: +async def cache_and_return_async_session( + endpoint_uri: URI, + session: Optional[ClientSession] = None, +) -> ClientSession: cache_key = generate_cache_key(endpoint_uri) - with _async_session_cache_lock: - evicted_items = _async_session_cache.cache(cache_key, session) - if evicted_items is not None: - for key, session in evicted_items.items(): - await session.close() - -async def get_async_session(endpoint_uri: URI) -> ClientSession: - cache_key = generate_cache_key(endpoint_uri) - if cache_key not in _async_session_cache: - await cache_async_session(endpoint_uri, ClientSession(raise_for_status=True)) - return _async_session_cache.get_cache_entry(cache_key) + evicted_items = None + async with async_lock(_async_session_cache_lock): + if cache_key not in _async_session_cache: + if session is None: + session = ClientSession(raise_for_status=True) + + evicted_items = _async_session_cache.cache(cache_key, session) + logger.debug(f"Async session cached: {endpoint_uri}, {session}") + + cached_session = _async_session_cache.get_cache_entry(cache_key) + + if evicted_items is not None: + # At this point the evicted sessions are already popped out of the cache and + # just stored in the `evicted_sessions` dict. So we can kick off a future task + # to close them and it should be safe to pop out of the lock here. + evicted_sessions = evicted_items.values() + for evicted_session in evicted_sessions: + logger.debug( + "Async session cache full. Session evicted from cache: " + f"{evicted_session}", + ) + # Kick off a future task, in a separate thread, to close the evicted + # sessions. In the case that the cache filled very quickly and some + # sessions have been evicted before their original request has been made, + # we set the timer to a bit more than the `DEFAULT_TIMEOUT` for a call. This + # should make it so that any call from an evicted session can still be made + # before the session is closed. + threading.Timer( + DEFAULT_TIMEOUT + 0.1, + _async_close_evicted_sessions, + args=[evicted_sessions], + ).start() + + return cached_session + + +@contextlib.asynccontextmanager +async def async_lock(lock: threading.Lock) -> AsyncGenerator[None, None]: + loop = asyncio.get_event_loop() + await loop.run_in_executor(_pool, lock.acquire) + try: + yield + finally: + lock.release() async def async_get_response_from_get_request( endpoint_uri: URI, *args: Any, **kwargs: Any ) -> ClientResponse: - kwargs.setdefault("timeout", ClientTimeout(10)) - session = await get_async_session(endpoint_uri) + kwargs.setdefault("timeout", ClientTimeout(DEFAULT_TIMEOUT)) + session = await cache_and_return_async_session(endpoint_uri) response = await session.get(endpoint_uri, *args, **kwargs) return response @@ -138,9 +213,10 @@ async def async_get_response_from_get_request( async def async_get_response_from_post_request( endpoint_uri: URI, *args: Any, **kwargs: Any ) -> ClientResponse: - kwargs.setdefault("timeout", ClientTimeout(10)) - session = await get_async_session(endpoint_uri) + kwargs.setdefault("timeout", ClientTimeout(DEFAULT_TIMEOUT)) + session = await cache_and_return_async_session(endpoint_uri) response = await session.post(endpoint_uri, *args, **kwargs) + return response @@ -157,3 +233,17 @@ async def async_get_json_from_client_response( response: ClientResponse, ) -> Dict[str, Any]: return await response.json() + + +def _async_close_evicted_sessions(evicted_sessions: List[ClientSession]) -> None: + loop = asyncio.new_event_loop() + + for evicted_session in evicted_sessions: + loop.run_until_complete(evicted_session.close()) + logger.debug(f"Closed evicted async session: {evicted_session}") + + if any(not evicted_session.closed for evicted_session in evicted_sessions): + logger.warning( + f"Some evicted async sessions were not properly closed: {evicted_sessions}" + ) + loop.close() diff --git a/web3/providers/async_rpc.py b/web3/providers/async_rpc.py index aa92f046ca..258f585bcc 100644 --- a/web3/providers/async_rpc.py +++ b/web3/providers/async_rpc.py @@ -23,7 +23,7 @@ ) from web3._utils.request import ( async_make_post_request, - cache_async_session as _cache_async_session, + cache_and_return_async_session as _cache_and_return_async_session, get_default_http_endpoint, ) from web3.types import ( @@ -55,8 +55,8 @@ def __init__( super().__init__() - async def cache_async_session(self, session: ClientSession) -> None: - await _cache_async_session(self.endpoint_uri, session) + async def cache_async_session(self, session: ClientSession) -> ClientSession: + return await _cache_and_return_async_session(self.endpoint_uri, session) def __str__(self) -> str: return f"RPC connection {self.endpoint_uri}" diff --git a/web3/providers/rpc.py b/web3/providers/rpc.py index aba4138e00..bffa69392a 100644 --- a/web3/providers/rpc.py +++ b/web3/providers/rpc.py @@ -19,7 +19,7 @@ construct_user_agent, ) from web3._utils.request import ( - cache_session, + cache_and_return_session, get_default_http_endpoint, make_post_request, ) @@ -62,7 +62,7 @@ def __init__( self._request_kwargs = request_kwargs or {} if session: - cache_session(self.endpoint_uri, session) + cache_and_return_session(self.endpoint_uri, session) super().__init__()