Skip to content

Commit 76ce961

Browse files
committed
Add support for async_simple_cache_middleware
1 parent 3c00810 commit 76ce961

File tree

3 files changed

+270
-0
lines changed

3 files changed

+270
-0
lines changed

tests/core/middleware/test_simple_cache_middleware.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,22 @@
1111
construct_result_generator_middleware,
1212
construct_simple_cache_middleware,
1313
)
14+
from web3.middleware.async_middleware.async_cache import (
15+
async_construct_simple_cache_middleware,
16+
)
17+
from web3.middleware.fixture import (
18+
async_construct_error_generator_middleware,
19+
async_construct_result_generator_middleware,
20+
)
1421
from web3.providers.base import (
1522
BaseProvider,
1623
)
24+
from web3.providers.eth_tester import (
25+
AsyncEthereumTesterProvider,
26+
)
27+
from web3.types import (
28+
RPCEndpoint,
29+
)
1730

1831

1932
@pytest.fixture
@@ -133,3 +146,153 @@ def test_simple_cache_middleware_does_not_cache_endpoints_not_in_whitelist(w3):
133146
result_b = w3.manager.request_blocking("not_whitelisted", [])
134147

135148
assert result_a != result_b
149+
150+
151+
# -- async -- #
152+
153+
154+
async def _async_simple_cache_middleware_for_testing(make_request, async_w3):
155+
middleware = await async_construct_simple_cache_middleware(
156+
cache_class=dict,
157+
rpc_whitelist={RPCEndpoint("fake_endpoint")},
158+
)
159+
return await middleware(make_request, async_w3)
160+
161+
162+
@pytest.fixture
163+
def async_w3():
164+
return Web3(
165+
provider=AsyncEthereumTesterProvider(),
166+
middlewares=[
167+
(_async_simple_cache_middleware_for_testing, "simple_cache"),
168+
],
169+
)
170+
171+
172+
@pytest.mark.asyncio
173+
async def test_async_simple_cache_middleware_pulls_from_cache(async_w3):
174+
async_w3.middleware_onion.remove("simple_cache")
175+
176+
def cache_class():
177+
return {
178+
generate_cache_key(("fake_endpoint", [1])): {"result": "value-a"},
179+
}
180+
181+
async def _properly_awaited_middleware(make_request, _async_w3):
182+
middleware = await async_construct_simple_cache_middleware(
183+
cache_class=cache_class,
184+
rpc_whitelist={RPCEndpoint("fake_endpoint")},
185+
)
186+
return await middleware(make_request, _async_w3)
187+
188+
async_w3.middleware_onion.inject(
189+
_properly_awaited_middleware,
190+
"for_this_test_only",
191+
layer=0,
192+
)
193+
194+
_result = await async_w3.manager.coro_request("fake_endpoint", [1])
195+
assert _result == "value-a"
196+
197+
# cleanup
198+
async_w3.middleware_onion.remove("for_this_test_only")
199+
async_w3.middleware_onion.add(
200+
_async_simple_cache_middleware_for_testing, "simple_cache"
201+
)
202+
203+
204+
@pytest.mark.asyncio
205+
async def test_async_simple_cache_middleware_populates_cache(async_w3):
206+
async_w3.middleware_onion.inject(
207+
await async_construct_result_generator_middleware(
208+
{
209+
RPCEndpoint("fake_endpoint"): lambda *_: str(uuid.uuid4()),
210+
}
211+
),
212+
"result_generator",
213+
layer=0,
214+
)
215+
216+
result = await async_w3.manager.coro_request("fake_endpoint", [])
217+
218+
_empty_params = await async_w3.manager.coro_request("fake_endpoint", [])
219+
_non_empty_params = await async_w3.manager.coro_request("fake_endpoint", [1])
220+
221+
assert _empty_params == result
222+
assert _non_empty_params != result
223+
224+
# cleanup
225+
async_w3.middleware_onion.remove("result_generator")
226+
227+
228+
@pytest.mark.asyncio
229+
async def test_async_simple_cache_middleware_does_not_cache_none_responses(async_w3):
230+
counter = itertools.count()
231+
232+
def result_cb(method, params):
233+
next(counter)
234+
return None
235+
236+
async_w3.middleware_onion.inject(
237+
await async_construct_result_generator_middleware(
238+
{
239+
RPCEndpoint("fake_endpoint"): result_cb,
240+
},
241+
),
242+
"result_generator",
243+
layer=0,
244+
)
245+
246+
await async_w3.manager.coro_request("fake_endpoint", [])
247+
await async_w3.manager.coro_request("fake_endpoint", [])
248+
249+
assert next(counter) == 2
250+
251+
# cleanup
252+
async_w3.middleware_onion.remove("result_generator")
253+
254+
255+
@pytest.mark.asyncio
256+
async def test_async_simple_cache_middleware_does_not_cache_error_responses(async_w3):
257+
async_w3.middleware_onion.inject(
258+
await async_construct_error_generator_middleware(
259+
{
260+
RPCEndpoint("fake_endpoint"): lambda *_: f"msg-{uuid.uuid4()}",
261+
}
262+
),
263+
"error_generator",
264+
layer=0,
265+
)
266+
267+
with pytest.raises(ValueError) as err_a:
268+
await async_w3.manager.coro_request("fake_endpoint", [])
269+
with pytest.raises(ValueError) as err_b:
270+
await async_w3.manager.coro_request("fake_endpoint", [])
271+
272+
assert str(err_a) != str(err_b)
273+
274+
# cleanup
275+
async_w3.middleware_onion.remove("error_generator")
276+
277+
278+
@pytest.mark.asyncio
279+
async def test_async_simple_cache_middleware_does_not_cache_non_whitelist_endpoints(
280+
async_w3,
281+
):
282+
async_w3.middleware_onion.inject(
283+
await async_construct_result_generator_middleware(
284+
{
285+
RPCEndpoint("not_whitelisted"): lambda *_: str(uuid.uuid4()),
286+
}
287+
),
288+
"result_generator",
289+
layer=0,
290+
)
291+
292+
result_a = await async_w3.manager.coro_request("not_whitelisted", [])
293+
result_b = await async_w3.manager.coro_request("not_whitelisted", [])
294+
295+
assert result_a != result_b
296+
297+
# cleanup
298+
async_w3.middleware_onion.remove("result_generator")

web3/middleware/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from .abi import ( # noqa: F401
1515
abi_middleware,
1616
)
17+
from .async_middleware.async_cache import ( # noqa: F401
18+
_async_simple_cache_middleware as async_simple_cache_middleware,
19+
)
1720
from .attrdict import ( # noqa: F401
1821
attrdict_middleware,
1922
)
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
import functools
2+
import threading
3+
from typing import (
4+
TYPE_CHECKING,
5+
Any,
6+
Callable,
7+
Collection,
8+
Dict,
9+
Tuple,
10+
Type,
11+
cast,
12+
)
13+
14+
import lru
15+
16+
from web3._utils.caching import (
17+
generate_cache_key,
18+
)
19+
from web3.types import (
20+
AsyncMiddleware,
21+
Middleware,
22+
RPCEndpoint,
23+
RPCResponse,
24+
)
25+
26+
if TYPE_CHECKING:
27+
from web3 import Web3 # noqa: F401
28+
29+
30+
SIMPLE_CACHE_RPC_WHITELIST = cast(
31+
Tuple[RPCEndpoint],
32+
(
33+
"web3_clientVersion",
34+
"eth_getBlockTransactionCountByHash",
35+
"eth_getUncleCountByBlockHash",
36+
"eth_getBlockByHash",
37+
"eth_getTransactionByHash",
38+
"eth_getTransactionByBlockHashAndIndex",
39+
"eth_getRawTransactionByHash",
40+
"eth_getUncleByBlockHashAndIndex",
41+
"eth_chainId",
42+
),
43+
)
44+
45+
46+
def _should_cache_response(
47+
_method: RPCEndpoint, _params: Any, response: RPCResponse
48+
) -> bool:
49+
return (
50+
"error" not in response
51+
and "result" in response
52+
and response["result"] is not None
53+
)
54+
55+
56+
async def async_construct_simple_cache_middleware(
57+
cache_class: Type[Dict[Any, Any]],
58+
rpc_whitelist: Collection[RPCEndpoint] = SIMPLE_CACHE_RPC_WHITELIST,
59+
should_cache_fn: Callable[
60+
[RPCEndpoint, Any, RPCResponse], bool
61+
] = _should_cache_response,
62+
) -> Middleware:
63+
"""
64+
Constructs a middleware which caches responses based on the request
65+
``method`` and ``params``
66+
67+
:param cache_class: Any dictionary-like object
68+
:param rpc_whitelist: A set of RPC methods which may have their responses cached.
69+
:param should_cache_fn: A callable which accepts ``method`` ``params`` and
70+
``response`` and returns a boolean as to whether the response should be
71+
cached.
72+
"""
73+
74+
async def async_simple_cache_middleware(
75+
make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3"
76+
) -> AsyncMiddleware:
77+
cache = cache_class()
78+
lock = threading.Lock()
79+
80+
async def middleware(method: RPCEndpoint, params: Any) -> RPCResponse:
81+
if method in rpc_whitelist:
82+
with lock:
83+
cache_key = generate_cache_key((method, params))
84+
if cache_key not in cache:
85+
response = await make_request(method, params)
86+
if should_cache_fn(method, params, response):
87+
cache[cache_key] = response
88+
return response
89+
return cache[cache_key]
90+
else:
91+
return await make_request(method, params)
92+
93+
return middleware
94+
95+
return async_simple_cache_middleware
96+
97+
98+
async def _async_simple_cache_middleware(
99+
make_request: Callable[[RPCEndpoint, Any], Any], async_w3: "Web3"
100+
):
101+
middleware = await async_construct_simple_cache_middleware(
102+
cache_class=cast(Type[Dict[Any, Any]], functools.partial(lru.LRU, 256)),
103+
)
104+
return await middleware(make_request, async_w3)

0 commit comments

Comments
 (0)