5454 Callable ,
5555 Collection ,
5656 Dict ,
57+ Iterable ,
5758 List ,
5859 Optional ,
5960 Set ,
61+ Tuple ,
6062)
6163
62- from synapse .appservice import ApplicationService , ApplicationServiceState
64+ from synapse .appservice import (
65+ ApplicationService ,
66+ ApplicationServiceState ,
67+ TransactionOneTimeKeyCounts ,
68+ TransactionUnusedFallbackKeys ,
69+ )
6370from synapse .appservice .api import ApplicationServiceApi
6471from synapse .events import EventBase
6572from synapse .logging .context import run_in_background
@@ -96,7 +103,7 @@ def __init__(self, hs: "HomeServer"):
96103 self .as_api = hs .get_application_service_api ()
97104
98105 self .txn_ctrl = _TransactionController (self .clock , self .store , self .as_api )
99- self .queuer = _ServiceQueuer (self .txn_ctrl , self .clock )
106+ self .queuer = _ServiceQueuer (self .txn_ctrl , self .clock , hs )
100107
101108 async def start (self ) -> None :
102109 logger .info ("Starting appservice scheduler" )
@@ -153,7 +160,9 @@ class _ServiceQueuer:
153160 appservice at a given time.
154161 """
155162
156- def __init__ (self , txn_ctrl : "_TransactionController" , clock : Clock ):
163+ def __init__ (
164+ self , txn_ctrl : "_TransactionController" , clock : Clock , hs : "HomeServer"
165+ ):
157166 # dict of {service_id: [events]}
158167 self .queued_events : Dict [str , List [EventBase ]] = {}
159168 # dict of {service_id: [events]}
@@ -165,6 +174,10 @@ def __init__(self, txn_ctrl: "_TransactionController", clock: Clock):
165174 self .requests_in_flight : Set [str ] = set ()
166175 self .txn_ctrl = txn_ctrl
167176 self .clock = clock
177+ self ._msc3202_transaction_extensions_enabled : bool = (
178+ hs .config .experimental .msc3202_transaction_extensions
179+ )
180+ self ._store = hs .get_datastores ().main
168181
169182 def start_background_request (self , service : ApplicationService ) -> None :
170183 # start a sender for this appservice if we don't already have one
@@ -202,15 +215,84 @@ async def _send_request(self, service: ApplicationService) -> None:
202215 if not events and not ephemeral and not to_device_messages_to_send :
203216 return
204217
218+ one_time_key_counts : Optional [TransactionOneTimeKeyCounts ] = None
219+ unused_fallback_keys : Optional [TransactionUnusedFallbackKeys ] = None
220+
221+ if (
222+ self ._msc3202_transaction_extensions_enabled
223+ and service .msc3202_transaction_extensions
224+ ):
225+ # Compute the one-time key counts and fallback key usage states
226+ # for the users which are mentioned in this transaction,
227+ # as well as the appservice's sender.
228+ (
229+ one_time_key_counts ,
230+ unused_fallback_keys ,
231+ ) = await self ._compute_msc3202_otk_counts_and_fallback_keys (
232+ service , events , ephemeral , to_device_messages_to_send
233+ )
234+
205235 try :
206236 await self .txn_ctrl .send (
207- service , events , ephemeral , to_device_messages_to_send
237+ service ,
238+ events ,
239+ ephemeral ,
240+ to_device_messages_to_send ,
241+ one_time_key_counts ,
242+ unused_fallback_keys ,
208243 )
209244 except Exception :
210245 logger .exception ("AS request failed" )
211246 finally :
212247 self .requests_in_flight .discard (service .id )
213248
249+ async def _compute_msc3202_otk_counts_and_fallback_keys (
250+ self ,
251+ service : ApplicationService ,
252+ events : Iterable [EventBase ],
253+ ephemerals : Iterable [JsonDict ],
254+ to_device_messages : Iterable [JsonDict ],
255+ ) -> Tuple [TransactionOneTimeKeyCounts , TransactionUnusedFallbackKeys ]:
256+ """
257+ Given a list of the events, ephemeral messages and to-device messages,
258+ - first computes a list of application services users that may have
259+ interesting updates to the one-time key counts or fallback key usage.
260+ - then computes one-time key counts and fallback key usages for those users.
261+ Given a list of application service users that are interesting,
262+ compute one-time key counts and fallback key usages for the users.
263+ """
264+
265+ # Set of 'interesting' users who may have updates
266+ users : Set [str ] = set ()
267+
268+ # The sender is always included
269+ users .add (service .sender )
270+
271+ # All AS users that would receive the PDUs or EDUs sent to these rooms
272+ # are classed as 'interesting'.
273+ rooms_of_interesting_users : Set [str ] = set ()
274+ # PDUs
275+ rooms_of_interesting_users .update (event .room_id for event in events )
276+ # EDUs
277+ rooms_of_interesting_users .update (
278+ ephemeral ["room_id" ] for ephemeral in ephemerals
279+ )
280+
281+ # Look up the AS users in those rooms
282+ for room_id in rooms_of_interesting_users :
283+ users .update (
284+ await self ._store .get_app_service_users_in_room (room_id , service )
285+ )
286+
287+ # Add recipients of to-device messages.
288+ # device_message["user_id"] is the ID of the recipient.
289+ users .update (device_message ["user_id" ] for device_message in to_device_messages )
290+
291+ # Compute and return the counts / fallback key usage states
292+ otk_counts = await self ._store .count_bulk_e2e_one_time_keys_for_as (users )
293+ unused_fbks = await self ._store .get_e2e_bulk_unused_fallback_key_types (users )
294+ return otk_counts , unused_fbks
295+
214296
215297class _TransactionController :
216298 """Transaction manager.
@@ -238,6 +320,8 @@ async def send(
238320 events : List [EventBase ],
239321 ephemeral : Optional [List [JsonDict ]] = None ,
240322 to_device_messages : Optional [List [JsonDict ]] = None ,
323+ one_time_key_counts : Optional [TransactionOneTimeKeyCounts ] = None ,
324+ unused_fallback_keys : Optional [TransactionUnusedFallbackKeys ] = None ,
241325 ) -> None :
242326 """
243327 Create a transaction with the given data and send to the provided
@@ -248,13 +332,19 @@ async def send(
248332 events: The persistent events to include in the transaction.
249333 ephemeral: The ephemeral events to include in the transaction.
250334 to_device_messages: The to-device messages to include in the transaction.
335+ one_time_key_counts: Counts of remaining one-time keys for relevant
336+ appservice devices in the transaction.
337+ unused_fallback_keys: Lists of unused fallback keys for relevant
338+ appservice devices in the transaction.
251339 """
252340 try :
253341 txn = await self .store .create_appservice_txn (
254342 service = service ,
255343 events = events ,
256344 ephemeral = ephemeral or [],
257345 to_device_messages = to_device_messages or [],
346+ one_time_key_counts = one_time_key_counts or {},
347+ unused_fallback_keys = unused_fallback_keys or {},
258348 )
259349 service_is_up = await self ._is_service_up (service )
260350 if service_is_up :
0 commit comments