Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Commit b83bc5f

Browse files
authored
Pull out less state when handling gaps mk2 (#12852)
1 parent 1b33847 commit b83bc5f

File tree

8 files changed

+236
-127
lines changed

8 files changed

+236
-127
lines changed

changelog.d/12852.misc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Pull out less state when handling gaps in room DAG.

synapse/handlers/federation_event.py

Lines changed: 84 additions & 94 deletions
Large diffs are not rendered by default.

synapse/handlers/message.py

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,14 @@
5555
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
5656
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
5757
from synapse.storage.state import StateFilter
58-
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
58+
from synapse.types import (
59+
MutableStateMap,
60+
Requester,
61+
RoomAlias,
62+
StreamToken,
63+
UserID,
64+
create_requester,
65+
)
5966
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
6067
from synapse.util.async_helpers import Linearizer, gather_results
6168
from synapse.util.caches.expiringcache import ExpiringCache
@@ -1022,8 +1029,35 @@ async def create_new_client_event(
10221029
#
10231030
# TODO(faster_joins): figure out how this works, and make sure that the
10241031
# old state is complete.
1025-
old_state = await self.store.get_events_as_list(state_event_ids)
1026-
context = await self.state.compute_event_context(event, old_state=old_state)
1032+
metadata = await self.store.get_metadata_for_events(state_event_ids)
1033+
1034+
state_map_for_event: MutableStateMap[str] = {}
1035+
for state_id in state_event_ids:
1036+
data = metadata.get(state_id)
1037+
if data is None:
1038+
# We're trying to persist a new historical batch of events
1039+
# with the given state, e.g. via
1040+
# `RoomBatchSendEventRestServlet`. The state can be inferred
1041+
# by Synapse or set directly by the client.
1042+
#
1043+
# Either way, we should have persisted all the state before
1044+
# getting here.
1045+
raise Exception(
1046+
f"State event {state_id} not found in DB,"
1047+
" Synapse should have persisted it before using it."
1048+
)
1049+
1050+
if data.state_key is None:
1051+
raise Exception(
1052+
f"Trying to set non-state event {state_id} as state"
1053+
)
1054+
1055+
state_map_for_event[(data.event_type, data.state_key)] = state_id
1056+
1057+
context = await self.state.compute_event_context(
1058+
event,
1059+
state_ids_before_event=state_map_for_event,
1060+
)
10271061
else:
10281062
context = await self.state.compute_event_context(event)
10291063

synapse/state/__init__.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ async def get_hosts_in_room_at_events(
261261
async def compute_event_context(
262262
self,
263263
event: EventBase,
264-
old_state: Optional[Iterable[EventBase]] = None,
264+
state_ids_before_event: Optional[StateMap[str]] = None,
265265
partial_state: bool = False,
266266
) -> EventContext:
267267
"""Build an EventContext structure for a non-outlier event.
@@ -273,26 +273,24 @@ async def compute_event_context(
273273
274274
Args:
275275
event:
276-
old_state: The state at the event if it can't be
277-
calculated from existing events. This is normally only specified
278-
when receiving an event from federation where we don't have the
279-
prev events for, e.g. when backfilling.
280-
partial_state: True if `old_state` is partial and omits non-critical
281-
membership events
276+
state_ids_before_event: The event ids of the state before the event if
277+
it can't be calculated from existing events. This is normally
278+
only specified when receiving an event from federation where we
279+
don't have the prev events, e.g. when backfilling.
280+
partial_state: True if `state_ids_before_event` is partial and omits
281+
non-critical membership events
282282
Returns:
283283
The event context.
284284
"""
285285

286286
assert not event.internal_metadata.is_outlier()
287287

288288
#
289-
# first of all, figure out the state before the event
289+
# first of all, figure out the state before the event, unless we
290+
# already have it.
290291
#
291-
if old_state:
292+
if state_ids_before_event:
292293
# if we're given the state before the event, then we use that
293-
state_ids_before_event: StateMap[str] = {
294-
(s.type, s.state_key): s.event_id for s in old_state
295-
}
296294
state_group_before_event = None
297295
state_group_before_event_prev_group = None
298296
deltas_to_state_group_before_event = None

synapse/storage/databases/main/state.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import logging
1717
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
1818

19+
import attr
20+
1921
from synapse.api.constants import EventTypes, Membership
2022
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
2123
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,13 +28,15 @@
2628
DatabasePool,
2729
LoggingDatabaseConnection,
2830
LoggingTransaction,
31+
make_in_list_sql_clause,
2932
)
3033
from synapse.storage.databases.main.events_worker import EventsWorkerStore
3134
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
3235
from synapse.storage.state import StateFilter
3336
from synapse.types import JsonDict, JsonMapping, StateMap
3437
from synapse.util.caches import intern_string
3538
from synapse.util.caches.descriptors import cached, cachedList
39+
from synapse.util.iterutils import batch_iter
3640

3741
if TYPE_CHECKING:
3842
from synapse.server import HomeServer
@@ -43,6 +47,15 @@
4347
MAX_STATE_DELTA_HOPS = 100
4448

4549

50+
@attr.s(slots=True, frozen=True, auto_attribs=True)
51+
class EventMetadata:
52+
"""Returned by `get_metadata_for_events`"""
53+
54+
room_id: str
55+
event_type: str
56+
state_key: Optional[str]
57+
58+
4659
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
4760
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
4861
if not v:
@@ -133,6 +146,52 @@ def get_room_version_id_txn(self, txn: LoggingTransaction, room_id: str) -> str:
133146

134147
return room_version
135148

149+
async def get_metadata_for_events(
150+
self, event_ids: Collection[str]
151+
) -> Dict[str, EventMetadata]:
152+
"""Get some metadata (room_id, type, state_key) for the given events.
153+
154+
This method is a faster alternative than fetching the full events from
155+
the DB, and should be used when the full event is not needed.
156+
157+
Returns metadata for rejected and redacted events. Events that have not
158+
been persisted are omitted from the returned dict.
159+
"""
160+
161+
def get_metadata_for_events_txn(
162+
txn: LoggingTransaction,
163+
batch_ids: Collection[str],
164+
) -> Dict[str, EventMetadata]:
165+
clause, args = make_in_list_sql_clause(
166+
self.database_engine, "e.event_id", batch_ids
167+
)
168+
169+
sql = f"""
170+
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
171+
LEFT JOIN state_events USING (event_id)
172+
WHERE {clause}
173+
"""
174+
175+
txn.execute(sql, args)
176+
return {
177+
event_id: EventMetadata(
178+
room_id=room_id, event_type=event_type, state_key=state_key
179+
)
180+
for event_id, room_id, event_type, state_key in txn
181+
}
182+
183+
result_map: Dict[str, EventMetadata] = {}
184+
for batch_ids in batch_iter(event_ids, 1000):
185+
result_map.update(
186+
await self.db_pool.runInteraction(
187+
"get_metadata_for_events",
188+
get_metadata_for_events_txn,
189+
batch_ids=batch_ids,
190+
)
191+
)
192+
193+
return result_map
194+
136195
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
137196
"""Get the predecessor of an upgraded room if it exists.
138197
Otherwise return None.

tests/handlers/test_federation.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -276,7 +276,11 @@ def test_backfill_with_many_backward_extremities(self) -> None:
276276
# federation handler wanting to backfill the fake event.
277277
self.get_success(
278278
federation_event_handler._process_received_pdu(
279-
self.OTHER_SERVER_NAME, event, state=current_state
279+
self.OTHER_SERVER_NAME,
280+
event,
281+
state_ids={
282+
(e.type, e.state_key): e.event_id for e in current_state
283+
},
280284
)
281285
)
282286

tests/storage/test_events.py

Lines changed: 28 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def prepare(self, reactor, clock, homeserver):
6969
def persist_event(self, event, state=None):
7070
"""Persist the event, with optional state"""
7171
context = self.get_success(
72-
self.state.compute_event_context(event, old_state=state)
72+
self.state.compute_event_context(event, state_ids_before_event=state)
7373
)
7474
self.get_success(self.persistence.persist_event(event, context))
7575

@@ -103,9 +103,11 @@ def test_prune_gap(self):
103103
RoomVersions.V6,
104104
)
105105

106-
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
106+
state_before_gap = self.get_success(
107+
self.state.get_current_state_ids(self.room_id)
108+
)
107109

108-
self.persist_event(remote_event_2, state=state_before_gap.values())
110+
self.persist_event(remote_event_2, state=state_before_gap)
109111

110112
# Check the new extremity is just the new remote event.
111113
self.assert_extremities([remote_event_2.event_id])
@@ -135,13 +137,14 @@ def test_do_not_prune_gap_if_state_different(self):
135137
# setting. The state resolution across the old and new event will then
136138
# include it, and so the resolved state won't match the new state.
137139
state_before_gap = dict(
138-
self.get_success(self.state.get_current_state(self.room_id))
140+
self.get_success(self.state.get_current_state_ids(self.room_id))
139141
)
140142
state_before_gap.pop(("m.room.history_visibility", ""))
141143

142144
context = self.get_success(
143145
self.state.compute_event_context(
144-
remote_event_2, old_state=state_before_gap.values()
146+
remote_event_2,
147+
state_ids_before_event=state_before_gap,
145148
)
146149
)
147150

@@ -177,9 +180,11 @@ def test_prune_gap_if_old(self):
177180
RoomVersions.V6,
178181
)
179182

180-
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
183+
state_before_gap = self.get_success(
184+
self.state.get_current_state_ids(self.room_id)
185+
)
181186

182-
self.persist_event(remote_event_2, state=state_before_gap.values())
187+
self.persist_event(remote_event_2, state=state_before_gap)
183188

184189
# Check the new extremity is just the new remote event.
185190
self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +212,11 @@ def test_do_not_prune_gap_if_other_server(self):
207212
RoomVersions.V6,
208213
)
209214

210-
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
215+
state_before_gap = self.get_success(
216+
self.state.get_current_state_ids(self.room_id)
217+
)
211218

212-
self.persist_event(remote_event_2, state=state_before_gap.values())
219+
self.persist_event(remote_event_2, state=state_before_gap)
213220

214221
# Check the new extremity is just the new remote event.
215222
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +254,11 @@ def test_prune_gap_if_dummy_remote(self):
247254
RoomVersions.V6,
248255
)
249256

250-
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
257+
state_before_gap = self.get_success(
258+
self.state.get_current_state_ids(self.room_id)
259+
)
251260

252-
self.persist_event(remote_event_2, state=state_before_gap.values())
261+
self.persist_event(remote_event_2, state=state_before_gap)
253262

254263
# Check the new extremity is just the new remote event.
255264
self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +298,11 @@ def test_prune_gap_if_dummy_local(self):
289298
RoomVersions.V6,
290299
)
291300

292-
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
301+
state_before_gap = self.get_success(
302+
self.state.get_current_state_ids(self.room_id)
303+
)
293304

294-
self.persist_event(remote_event_2, state=state_before_gap.values())
305+
self.persist_event(remote_event_2, state=state_before_gap)
295306

296307
# Check the new extremity is just the new remote event.
297308
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +334,11 @@ def test_do_not_prune_gap_if_not_dummy(self):
323334
RoomVersions.V6,
324335
)
325336

326-
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
337+
state_before_gap = self.get_success(
338+
self.state.get_current_state_ids(self.room_id)
339+
)
327340

328-
self.persist_event(remote_event_2, state=state_before_gap.values())
341+
self.persist_event(remote_event_2, state=state_before_gap)
329342

330343
# Check the new extremity is just the new remote event.
331344
self.assert_extremities([local_message_event_id, remote_event_2.event_id])

tests/test_state.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,12 @@ def test_annotate_with_old_message(self):
442442
]
443443

444444
context = yield defer.ensureDeferred(
445-
self.state.compute_event_context(event, old_state=old_state)
445+
self.state.compute_event_context(
446+
event,
447+
state_ids_before_event={
448+
(e.type, e.state_key): e.event_id for e in old_state
449+
},
450+
)
446451
)
447452

448453
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ def test_annotate_with_old_state(self):
467472
]
468473

469474
context = yield defer.ensureDeferred(
470-
self.state.compute_event_context(event, old_state=old_state)
475+
self.state.compute_event_context(
476+
event,
477+
state_ids_before_event={
478+
(e.type, e.state_key): e.event_id for e in old_state
479+
},
480+
)
471481
)
472482

473483
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())

0 commit comments

Comments
 (0)