Skip to content

Commit 30f76fa

Browse files
committed
Separate batch send from batch receive; cache request funcs:
- Separate the batch send and batch receive functions for persistent connection providers in order to deterministically cache the request information and be able to retrieve it without any request id guesswork.
- Remove the idea of a batching request counter / id 🎉
- Correct typing expectations for ``BatchRequestInformation``
1 parent d939727 commit 30f76fa

File tree

13 files changed

+256
-130
lines changed

13 files changed

+256
-130
lines changed

tests/core/providers/test_async_http_provider.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
AsyncWeb3,
1313
__version__ as web3py_version,
1414
)
15-
from web3._utils.batching import (
16-
is_batching_context,
17-
)
1815
from web3.eth import (
1916
AsyncEth,
2017
)
@@ -131,4 +128,4 @@ async def test_async_http_empty_batch_response(mock_async_post):
131128
with pytest.raises(Web3RPCError, match="empty batch"):
132129
await batch.async_execute()
133130

134-
assert not is_batching_context()
131+
assert not async_w3.provider._is_batching

tests/core/providers/test_async_ipc_provider.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@
2020
from web3 import (
2121
AsyncWeb3,
2222
)
23-
from web3._utils.batching import (
24-
is_batching_context,
25-
)
2623
from web3.datastructures import (
2724
AttributeDict,
2825
)
@@ -388,8 +385,8 @@ async def test_persistent_connection_provider_empty_batch_response(
388385
)
389386
)
390387
async with async_w3.batch_requests() as batch:
391-
assert is_batching_context()
388+
assert async_w3.provider._is_batching
392389
await batch.async_execute()
393390

394391
# assert that even though there was an error, we have reset the batching state
395-
assert not is_batching_context()
392+
assert not async_w3.provider._is_batching

tests/core/providers/test_http_provider.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@
1515
Web3,
1616
__version__ as web3py_version,
1717
)
18-
from web3._utils.batching import (
19-
is_batching_context,
20-
)
2118
from web3.eth import (
2219
Eth,
2320
)
@@ -130,9 +127,9 @@ def test_http_empty_batch_response(mock_post):
130127
)
131128
w3 = Web3(HTTPProvider())
132129
with w3.batch_requests() as batch:
133-
assert is_batching_context()
130+
assert w3.provider._is_batching
134131
with pytest.raises(Web3RPCError, match="empty batch"):
135132
batch.execute()
136133

137134
# assert that even though there was an error, we have reset the batching state
138-
assert not is_batching_context()
135+
assert not w3.provider._is_batching

tests/core/providers/test_websocket_provider.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@
1818
from web3 import (
1919
AsyncWeb3,
2020
)
21-
from web3._utils.batching import (
22-
is_batching_context,
23-
)
2421
from web3._utils.caching import (
2522
RequestInformation,
2623
generate_cache_key,
@@ -479,7 +476,7 @@ async def test_persistent_connection_provider_empty_batch_response():
479476
with pytest.raises(Web3RPCError, match="empty batch"):
480477
async with AsyncWeb3(WebSocketProvider("ws://mocked")) as async_w3:
481478
async with async_w3.batch_requests() as batch:
482-
assert is_batching_context()
479+
assert async_w3.provider._is_batching
483480
async_w3.provider._ws.recv = AsyncMock()
484481
async_w3.provider._ws.recv.return_value = (
485482
b'{"jsonrpc": "2.0","id":null,"error": {"code": -32600, '
@@ -489,7 +486,7 @@ async def test_persistent_connection_provider_empty_batch_response():
489486

490487
# assert that even though there was an error, we have reset the batching
491488
# state
492-
assert not is_batching_context()
489+
assert not async_w3.provider._is_batching
493490

494491

495492
@pytest.mark.parametrize(

web3/_utils/batching.py

Lines changed: 13 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import contextvars
2-
from copy import (
3-
copy,
4-
)
51
from types import (
62
TracebackType,
73
)
@@ -13,8 +9,6 @@
139
Dict,
1410
Generic,
1511
List,
16-
Optional,
17-
Sequence,
1812
Tuple,
1913
Type,
2014
Union,
@@ -29,7 +23,6 @@
2923
Web3ValueError,
3024
)
3125
from web3.types import (
32-
RPCEndpoint,
3326
TFunc,
3427
TReturn,
3528
)
@@ -58,24 +51,14 @@
5851
JSONBaseProvider,
5952
)
6053
from web3.types import ( # noqa: F401
54+
RPCEndpoint,
6155
RPCResponse,
6256
)
6357

6458

6559
BATCH_REQUEST_ID = "batch_request" # for use as the cache key for batch requests
6660

67-
# control batching context via a context var
68-
_batching_context: contextvars.ContextVar[
69-
Optional["RequestBatcher[Any]"]
70-
] = contextvars.ContextVar("batching_context", default=None)
71-
72-
73-
def is_batching_context() -> bool:
74-
"""Check if we're currently in a batching context."""
75-
return _batching_context.get() is not None
76-
77-
78-
BatchRequestInformation = Tuple[Tuple["RPCEndpoint", Any], Sequence[Any]]
61+
BatchRequestInformation = Tuple[Tuple["RPCEndpoint", Any], Tuple[Any, ...]]
7962
RPC_METHODS_UNSUPPORTED_DURING_BATCH = {
8063
"eth_subscribe",
8164
"eth_unsubscribe",
@@ -105,22 +88,19 @@ def _provider(self) -> Union["JSONBaseProvider", "AsyncJSONBaseProvider"]:
10588
)
10689

10790
def _validate_is_batching(self) -> None:
108-
if not is_batching_context():
91+
if not self._provider._is_batching:
10992
raise Web3ValueError(
11093
"Batch has already been executed or cancelled. Create a new batch to "
11194
"issue batched requests."
11295
)
11396

11497
def _initialize_batching(self) -> None:
115-
_batching_context.set(self)
98+
self._provider._batching_context.set(self)
11699
self.clear()
117100

118101
def _end_batching(self) -> None:
119-
_batching_context.set(None)
120102
self.clear()
121-
if self._provider.has_persistent_connection:
122-
provider = cast("PersistentConnectionProvider", self._provider)
123-
provider._batch_request_counter = None
103+
self._provider._batching_context.set(None)
124104

125105
def add(self, batch_payload: TReturn) -> None:
126106
self._validate_is_batching()
@@ -164,9 +144,6 @@ def execute(self) -> List["RPCResponse"]:
164144
def clear(self) -> None:
165145
self._requests_info = []
166146
self._async_requests_info = []
167-
if self._provider.has_persistent_connection:
168-
provider = cast("PersistentConnectionProvider", self._provider)
169-
provider._batch_request_counter = next(copy(provider.request_counter))
170147

171148
def cancel(self) -> None:
172149
self._end_batching()
@@ -189,9 +166,14 @@ def __exit__(
189166

190167
async def async_execute(self) -> List["RPCResponse"]:
191168
self._validate_is_batching()
192-
responses = await self.web3.manager._async_make_batch_request(
193-
self._async_requests_info
194-
)
169+
if self._provider.has_persistent_connection:
170+
responses = await self.web3.manager._async_make_socket_batch_request(
171+
self._async_requests_info
172+
)
173+
else:
174+
responses = await self.web3.manager._async_make_batch_request(
175+
self._async_requests_info
176+
)
195177
self._end_batching()
196178
return responses
197179

web3/_utils/module_testing/web3_module.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@
2525
AsyncWeb3,
2626
Web3,
2727
)
28-
from web3._utils.batching import (
29-
is_batching_context,
30-
)
3128
from web3._utils.ens import (
3229
ens_addresses,
3330
)
@@ -345,7 +342,7 @@ def test_batch_requests(self, w3: "Web3", math_contract: Contract) -> None:
345342

346343
# assert proper batch cleanup after execution
347344
assert batch._requests_info == []
348-
assert not is_batching_context()
345+
assert not w3.provider._is_batching
349346

350347
# assert batch cannot be added to after execution
351348
with pytest.raises(
@@ -404,7 +401,7 @@ def test_batch_requests_initialized_as_object(
404401

405402
# assert proper batch cleanup after execution
406403
assert batch._requests_info == []
407-
assert not is_batching_context()
404+
assert not w3.provider._is_batching
408405

409406
# assert batch cannot be added to after execution
410407
with pytest.raises(
@@ -597,7 +594,7 @@ async def test_batch_requests( # type: ignore[override]
597594

598595
# assert proper batch cleanup after execution
599596
assert batch._async_requests_info == []
600-
assert not is_batching_context()
597+
assert not async_w3.provider._is_batching
601598

602599
# assert batch cannot be added to after execution
603600
with pytest.raises(
@@ -660,7 +657,7 @@ async def test_batch_requests_initialized_as_object( # type: ignore[override]
660657

661658
# assert proper batch cleanup after execution
662659
assert batch._async_requests_info == []
663-
assert not is_batching_context()
660+
assert not async_w3.provider._is_batching
664661

665662
# assert batch cannot be added to after execution
666663
with pytest.raises(
@@ -785,19 +782,20 @@ async def test_batch_requests_raises_for_common_unsupported_methods( # type: ig
785782
async def test_batch_requests_concurrently_with_regular_requests( # type: ignore[override] # noqa: E501
786783
self, async_w3: AsyncWeb3 # type: ignore[override]
787784
) -> None:
788-
num_requests = 40
789785
responses = []
790786
batch_response = []
791787

788+
num_blocks = await async_w3.eth.block_number
789+
792790
async def make_regular_requests() -> None:
793-
for _ in range(num_requests):
794-
responses.append(await async_w3.eth.get_block(0))
791+
for i in range(num_blocks):
792+
responses.append(await async_w3.eth.get_block(i))
795793
await asyncio.sleep(0.01)
796794

797795
async def make_batch_request() -> None:
798796
async with async_w3.batch_requests() as batch:
799-
for _ in range(num_requests):
800-
batch.add(async_w3.eth.get_block(0))
797+
for i in range(num_blocks):
798+
batch.add(async_w3.eth.get_block(i))
801799
await asyncio.sleep(0.01)
802800
batch_response.extend(await batch.async_execute())
803801

@@ -806,7 +804,7 @@ async def make_batch_request() -> None:
806804
make_batch_request(),
807805
)
808806

809-
assert len(responses) == num_requests
810-
assert len(batch_response) == num_requests
807+
assert len(responses) == num_blocks
808+
assert len(batch_response) == num_blocks
811809
assert all(SOME_BLOCK_KEYS.issubset(response.keys()) for response in responses)
812810
assert set(responses) == set(batch_response)

web3/contract/utils.py

Lines changed: 22 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@
4444
async_fill_transaction_defaults,
4545
)
4646
from web3._utils.batching import (
47-
is_batching_context,
47+
BatchRequestInformation,
4848
)
4949
from web3._utils.compat import (
5050
TypeAlias,
@@ -178,10 +178,8 @@ def call_contract_function(
178178
if abi_callable["type"] == "function":
179179
output_types = get_abi_output_types(abi_callable)
180180

181-
w3.provider
182-
if is_batching_context():
183-
BatchingReturnData: TypeAlias = Tuple[Tuple[RPCEndpoint, Any], Tuple[Any, ...]]
184-
request_information = tuple(cast(BatchingReturnData, return_data))
181+
if w3.provider._is_batching:
182+
request_information = tuple(cast(BatchRequestInformation, return_data))
185183
method_and_params = request_information[0]
186184

187185
# append return data formatting to result formatters
@@ -477,7 +475,7 @@ async def async_call_contract_function(
477475
if fn_abi["type"] == "function":
478476
output_types = get_abi_output_types(fn_abi)
479477

480-
if is_batching_context():
478+
if async_w3.provider._is_batching:
481479
contract_call_return_data_formatter = format_contract_call_return_data_curried(
482480
async_w3,
483481
decode_tuples,
@@ -486,33 +484,24 @@ async def async_call_contract_function(
486484
normalizers,
487485
output_types,
488486
)
489-
if async_w3.provider.has_persistent_connection:
490-
# get the current request id
491-
provider = cast("PersistentConnectionProvider", async_w3.provider)
492-
current_request_id = provider._batch_request_counter - 1
493-
provider._request_processor.append_result_formatter_for_request(
494-
current_request_id, contract_call_return_data_formatter
495-
)
496-
else:
497-
BatchingReturnData: TypeAlias = Tuple[
498-
Tuple[RPCEndpoint, Any], Tuple[Any, ...]
499-
]
500-
request_information = tuple(cast(BatchingReturnData, return_data))
501-
method_and_params = request_information[0]
502-
503-
# append return data formatter to result formatters
504-
current_response_formatters = request_information[1]
505-
current_result_formatters = current_response_formatters[0]
506-
updated_result_formatters = compose(
507-
contract_call_return_data_formatter,
508-
current_result_formatters,
509-
)
510-
response_formatters = (
511-
updated_result_formatters, # result formatters
512-
current_response_formatters[1], # error formatters
513-
current_response_formatters[2], # null result formatters
514-
)
515-
return (method_and_params, response_formatters)
487+
488+
BatchingReturnData: TypeAlias = Tuple[Tuple[RPCEndpoint, Any], Tuple[Any, ...]]
489+
request_information = tuple(cast(BatchingReturnData, return_data))
490+
method_and_params = request_information[0]
491+
492+
# append return data formatter to result formatters
493+
current_response_formatters = request_information[1]
494+
current_result_formatters = current_response_formatters[0]
495+
updated_result_formatters = compose(
496+
contract_call_return_data_formatter,
497+
current_result_formatters,
498+
)
499+
response_formatters = (
500+
updated_result_formatters, # result formatters
501+
current_response_formatters[1], # error formatters
502+
current_response_formatters[2], # null result formatters
503+
)
504+
return (method_and_params, response_formatters)
516505

517506
return return_data
518507

0 commit comments

Comments (0)