Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit dcb6a37

Browse files
authored
Cap the number of in-flight requests for state from a single group (#11608)
1 parent 7bcc28f commit dcb6a37

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

changelog.d/11608.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Deduplicate in-flight requests in `_get_state_for_groups`.

synapse/storage/databases/state/store.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
logger = logging.getLogger(__name__)
5757

5858
MAX_STATE_DELTA_HOPS = 100
59+
MAX_INFLIGHT_REQUESTS_PER_GROUP = 5
5960

6061

6162
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -258,6 +259,11 @@ def _get_state_for_group_gather_inflight_requests(
258259
Attempts to gather in-flight requests and re-use them to retrieve state
259260
for the given state group, filtered with the given state filter.
260261
262+
If there are more than MAX_INFLIGHT_REQUESTS_PER_GROUP in-flight requests,
263+
and there *still* isn't enough information to complete the request by solely
264+
reusing others, a full state filter will be requested to ensure that subsequent
265+
requests can reuse this request.
266+
261267
Used as part of _get_state_for_group_using_inflight_cache.
262268
263269
Returns:
@@ -288,6 +294,16 @@ def _get_state_for_group_gather_inflight_requests(
288294
# to cover our StateFilter and give us the state we need.
289295
break
290296

297+
if (
298+
state_filter_left_over != StateFilter.none()
299+
and len(inflight_requests) >= MAX_INFLIGHT_REQUESTS_PER_GROUP
300+
):
301+
# There are too many requests for this group.
302+
# To prevent even more from building up, we request the whole
303+
# state filter to guarantee that we can be reused by any subsequent
304+
# requests for this state group.
305+
return (), StateFilter.all()
306+
291307
return reusable_requests, state_filter_left_over
292308

293309
async def _get_state_for_group_fire_request(

tests/storage/databases/test_state_store.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from twisted.test.proto_helpers import MemoryReactor
2020

2121
from synapse.api.constants import EventTypes
22+
from synapse.storage.databases.state.store import MAX_INFLIGHT_REQUESTS_PER_GROUP
2223
from synapse.storage.state import StateFilter
2324
from synapse.types import StateMap
2425
from synapse.util import Clock
@@ -281,3 +282,71 @@ def test_in_flight_requests_stop_being_in_flight(self) -> None:
281282

282283
self.assertEqual(self.get_success(req1), FAKE_STATE)
283284
self.assertEqual(self.get_success(req2), FAKE_STATE)
285+
286+
def test_inflight_requests_capped(self) -> None:
287+
"""
288+
Tests that the number of in-flight requests is capped to 5.
289+
290+
- requests several pieces of state separately
291+
(5 to hit the limit, 1 to 'shunt out', another that comes after the
292+
group has been 'shunted out')
293+
- checks to see that the torrent of requests is shunted out by
294+
rewriting one of the filters as the 'all' state filter
295+
- requests after that one do not cause any additional queries
296+
"""
297+
# 5 at the time of writing.
298+
CAP_COUNT = MAX_INFLIGHT_REQUESTS_PER_GROUP
299+
300+
reqs = []
301+
302+
# Request 7 different keys (1 to 7) of the `some.state` type.
303+
for req_id in range(CAP_COUNT + 2):
304+
reqs.append(
305+
ensureDeferred(
306+
self.state_datastore._get_state_for_group_using_inflight_cache(
307+
42,
308+
StateFilter.freeze(
309+
{"some.state": {str(req_id + 1)}}, include_others=False
310+
),
311+
)
312+
)
313+
)
314+
self.pump(by=0.1)
315+
316+
# There should only be 6 calls to the database, not 7.
317+
self.assertEqual(len(self.get_state_group_calls), CAP_COUNT + 1)
318+
319+
# Assert that the first 5 are exact requests for the individual pieces
320+
# wanted
321+
for req_id in range(CAP_COUNT):
322+
groups, sf, d = self.get_state_group_calls[req_id]
323+
self.assertEqual(
324+
sf,
325+
StateFilter.freeze(
326+
{"some.state": {str(req_id + 1)}}, include_others=False
327+
),
328+
)
329+
330+
# The 6th request should be the 'all' state filter
331+
groups, sf, d = self.get_state_group_calls[CAP_COUNT]
332+
self.assertEqual(sf, StateFilter.all())
333+
334+
# Complete the queries and check which requests complete as a result
335+
for req_id in range(CAP_COUNT):
336+
# This request should not have been completed yet
337+
self.assertFalse(reqs[req_id].called)
338+
339+
groups, sf, d = self.get_state_group_calls[req_id]
340+
self._complete_request_fake(groups, sf, d)
341+
342+
# This should have only completed this one request
343+
self.assertTrue(reqs[req_id].called)
344+
345+
# Now complete the final query; the last 2 requests should complete
346+
# as a result
347+
self.assertFalse(reqs[CAP_COUNT].called)
348+
self.assertFalse(reqs[CAP_COUNT + 1].called)
349+
groups, sf, d = self.get_state_group_calls[CAP_COUNT]
350+
self._complete_request_fake(groups, sf, d)
351+
self.assertTrue(reqs[CAP_COUNT].called)
352+
self.assertTrue(reqs[CAP_COUNT + 1].called)

0 commit comments

Comments
 (0)