Skip to content

Adding locks when manipulating the _session_cache object in case of multi-threading #1982

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 62 additions & 7 deletions web3/_utils/request.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
import os
from threading import (
RLock,
)
from typing import (
Any,
)

from aiohttp import (
ClientSession,
ClientTimeout,
)
from eth_typing import (
URI,
)
Expand All @@ -12,23 +19,59 @@
generate_cache_key,
)

_lock = RLock()


def _acquireLock() -> None:
"""
Acquire the module-level lock for serializing access to shared cache.

This should be released with _releaseLock().
"""
if _lock:
_lock.acquire()


def _releaseLock() -> None:
"""
Release the module-level lock acquired by calling _acquireLock().
"""
if _lock:
_lock.release()


def get_default_http_endpoint() -> URI:
return URI(os.environ.get('WEB3_HTTP_PROVIDER_URI', 'http://localhost:8545'))


def cache_session(endpoint_uri: URI, session: requests.Session) -> None:
cache_key = generate_cache_key(endpoint_uri)
_session_cache[cache_key] = session
_acquireLock()
try:
cache_key = generate_cache_key(endpoint_uri)
_session_cache[cache_key] = session
finally:
_releaseLock()


def _remove_session(key: str, session: requests.Session) -> None:
session.close()
_acquireLock()
try:
session.close()
finally:
_releaseLock()


_session_cache = lru.LRU(8, callback=_remove_session)


def _get_session(endpoint_uri: URI) -> requests.Session:
cache_key = generate_cache_key(endpoint_uri)
if cache_key not in _session_cache:
_session_cache[cache_key] = requests.Session()
_acquireLock()
try:
cache_key = generate_cache_key(endpoint_uri)
if cache_key not in _session_cache:
_session_cache[cache_key] = requests.Session()
finally:
_releaseLock()
return _session_cache[cache_key]


Expand All @@ -40,3 +83,15 @@ def make_post_request(endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any)
response.raise_for_status()

return response.content


async def async_make_post_request(
endpoint_uri: URI, data: bytes, *args: Any, **kwargs: Any
) -> bytes:
kwargs.setdefault('timeout', ClientTimeout(10))
async with ClientSession(raise_for_status=True) as session:
async with session.post(endpoint_uri,
data=data,
*args,
**kwargs) as response:
return await response.read()