1414# limitations under the License.
1515
1616import logging
17- from typing import TYPE_CHECKING , Any , Dict , Iterable , List , Optional , Tuple
17+ from typing import (
18+ TYPE_CHECKING ,
19+ Any ,
20+ Collection ,
21+ Dict ,
22+ Iterable ,
23+ List ,
24+ Optional ,
25+ Set ,
26+ Tuple ,
27+ )
1828
1929from twisted .internet import defer
2030
31+ from synapse .api .constants import ReceiptTypes
2132from synapse .replication .slave .storage ._slaved_id_tracker import SlavedIdTracker
2233from synapse .replication .tcp .streams import ReceiptsStream
2334from synapse .storage ._base import SQLBaseStore , db_to_json , make_in_list_sql_clause
24- from synapse .storage .database import DatabasePool
35+ from synapse .storage .database import DatabasePool , LoggingTransaction
2536from synapse .storage .engines import PostgresEngine
2637from synapse .storage .util .id_generators import MultiWriterIdGenerator , StreamIdGenerator
2738from synapse .types import JsonDict
@@ -78,17 +89,13 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
7889 "ReceiptsRoomChangeCache" , self .get_max_receipt_stream_id ()
7990 )
8091
81- def get_max_receipt_stream_id (self ):
82- """Get the current max stream ID for receipts stream
83-
84- Returns:
85- int
86- """
92+ def get_max_receipt_stream_id (self ) -> int :
93+ """Get the current max stream ID for receipts stream"""
8794 return self ._receipts_id_gen .get_current_token ()
8895
8996 @cached ()
90- async def get_users_with_read_receipts_in_room (self , room_id ) :
91- receipts = await self .get_receipts_for_room (room_id , "m.read" )
97+ async def get_users_with_read_receipts_in_room (self , room_id : str ) -> Set [ str ] :
98+ receipts = await self .get_receipts_for_room (room_id , ReceiptTypes . READ )
9299 return {r ["user_id" ] for r in receipts }
93100
94101 @cached (num_args = 2 )
@@ -119,7 +126,9 @@ async def get_last_receipt_event_id_for_user(
119126 )
120127
121128 @cached (num_args = 2 )
122- async def get_receipts_for_user (self , user_id , receipt_type ):
129+ async def get_receipts_for_user (
130+ self , user_id : str , receipt_type : str
131+ ) -> Dict [str , str ]:
123132 rows = await self .db_pool .simple_select_list (
124133 table = "receipts_linearized" ,
125134 keyvalues = {"user_id" : user_id , "receipt_type" : receipt_type },
@@ -129,8 +138,10 @@ async def get_receipts_for_user(self, user_id, receipt_type):
129138
130139 return {row ["room_id" ]: row ["event_id" ] for row in rows }
131140
132- async def get_receipts_for_user_with_orderings (self , user_id , receipt_type ):
133- def f (txn ):
141+ async def get_receipts_for_user_with_orderings (
142+ self , user_id : str , receipt_type : str
143+ ) -> JsonDict :
144+ def f (txn : LoggingTransaction ) -> List [Tuple [str , str , int , int ]]:
134145 sql = (
135146 "SELECT rl.room_id, rl.event_id,"
136147 " e.topological_ordering, e.stream_ordering"
@@ -209,10 +220,10 @@ async def get_linearized_receipts_for_room(
209220 @cached (num_args = 3 , tree = True )
210221 async def _get_linearized_receipts_for_room (
211222 self , room_id : str , to_key : int , from_key : Optional [int ] = None
212- ) -> List [dict ]:
223+ ) -> List [JsonDict ]:
213224 """See get_linearized_receipts_for_room"""
214225
215- def f (txn ) :
226+ def f (txn : LoggingTransaction ) -> List [ Dict [ str , Any ]] :
216227 if from_key :
217228 sql = (
218229 "SELECT * FROM receipts_linearized WHERE"
@@ -250,11 +261,13 @@ def f(txn):
250261 list_name = "room_ids" ,
251262 num_args = 3 ,
252263 )
253- async def _get_linearized_receipts_for_rooms (self , room_ids , to_key , from_key = None ):
264+ async def _get_linearized_receipts_for_rooms (
265+ self , room_ids : Collection [str ], to_key : int , from_key : Optional [int ] = None
266+ ) -> Dict [str , List [JsonDict ]]:
254267 if not room_ids :
255268 return {}
256269
257- def f (txn ) :
270+ def f (txn : LoggingTransaction ) -> List [ Dict [ str , Any ]] :
258271 if from_key :
259272 sql = """
260273 SELECT * FROM receipts_linearized WHERE
@@ -323,7 +336,7 @@ async def get_linearized_receipts_for_all_rooms(
323336 A dictionary of roomids to a list of receipts.
324337 """
325338
326- def f (txn ) :
339+ def f (txn : LoggingTransaction ) -> List [ Dict [ str , Any ]] :
327340 if from_key :
328341 sql = """
329342 SELECT * FROM receipts_linearized WHERE
@@ -379,7 +392,7 @@ async def get_users_sent_receipts_between(
379392 if last_id == current_id :
380393 return defer .succeed ([])
381394
382- def _get_users_sent_receipts_between_txn (txn ) :
395+ def _get_users_sent_receipts_between_txn (txn : LoggingTransaction ) -> List [ str ] :
383396 sql = """
384397 SELECT DISTINCT user_id FROM receipts_linearized
385398 WHERE ? < stream_id AND stream_id <= ?
@@ -419,7 +432,9 @@ async def get_all_updated_receipts(
419432 if last_id == current_id :
420433 return [], current_id , False
421434
422- def get_all_updated_receipts_txn (txn ):
435+ def get_all_updated_receipts_txn (
436+ txn : LoggingTransaction ,
437+ ) -> Tuple [List [Tuple [int , list ]], int , bool ]:
423438 sql = """
424439 SELECT stream_id, room_id, receipt_type, user_id, event_id, data
425440 FROM receipts_linearized
@@ -446,8 +461,8 @@ def get_all_updated_receipts_txn(txn):
446461
447462 def _invalidate_get_users_with_receipts_in_room (
448463 self , room_id : str , receipt_type : str , user_id : str
449- ):
450- if receipt_type != "m.read" :
464+ ) -> None :
465+ if receipt_type != ReceiptTypes . READ :
451466 return
452467
453468 res = self .get_users_with_read_receipts_in_room .cache .get_immediate (
@@ -461,7 +476,9 @@ def _invalidate_get_users_with_receipts_in_room(
461476
462477 self .get_users_with_read_receipts_in_room .invalidate ((room_id ,))
463478
464- def invalidate_caches_for_receipt (self , room_id , receipt_type , user_id ):
479+ def invalidate_caches_for_receipt (
480+ self , room_id : str , receipt_type : str , user_id : str
481+ ) -> None :
465482 self .get_receipts_for_user .invalidate ((user_id , receipt_type ))
466483 self ._get_linearized_receipts_for_room .invalidate ((room_id ,))
467484 self .get_last_receipt_event_id_for_user .invalidate (
@@ -482,11 +499,18 @@ def process_replication_rows(self, stream_name, instance_name, token, rows):
482499 return super ().process_replication_rows (stream_name , instance_name , token , rows )
483500
484501 def insert_linearized_receipt_txn (
485- self , txn , room_id , receipt_type , user_id , event_id , data , stream_id
486- ):
502+ self ,
503+ txn : LoggingTransaction ,
504+ room_id : str ,
505+ receipt_type : str ,
506+ user_id : str ,
507+ event_id : str ,
508+ data : JsonDict ,
509+ stream_id : int ,
510+ ) -> Optional [int ]:
487511 """Inserts a read-receipt into the database if it's newer than the current RR
488512
489- Returns: int|None
513+ Returns:
490514 None if the RR is older than the current RR
491515 otherwise, the rx timestamp of the event that the RR corresponds to
492516 (or 0 if the event is unknown)
@@ -550,7 +574,7 @@ def insert_linearized_receipt_txn(
550574 lock = False ,
551575 )
552576
553- if receipt_type == "m.read" and stream_ordering is not None :
577+ if receipt_type == ReceiptTypes . READ and stream_ordering is not None :
554578 self ._remove_old_push_actions_before_txn (
555579 txn , room_id = room_id , user_id = user_id , stream_ordering = stream_ordering
556580 )
@@ -580,7 +604,7 @@ async def insert_receipt(
580604 else :
581605 # we need to points in graph -> linearized form.
582606 # TODO: Make this better.
583- def graph_to_linear (txn ) :
607+ def graph_to_linear (txn : LoggingTransaction ) -> str :
584608 clause , args = make_in_list_sql_clause (
585609 self .database_engine , "event_id" , event_ids
586610 )
@@ -634,11 +658,16 @@ def graph_to_linear(txn):
634658 return stream_id , max_persisted_id
635659
636660 async def insert_graph_receipt (
637- self , room_id , receipt_type , user_id , event_ids , data
638- ):
661+ self ,
662+ room_id : str ,
663+ receipt_type : str ,
664+ user_id : str ,
665+ event_ids : List [str ],
666+ data : JsonDict ,
667+ ) -> None :
639668 assert self ._can_write_to_receipts
640669
641- return await self .db_pool .runInteraction (
670+ await self .db_pool .runInteraction (
642671 "insert_graph_receipt" ,
643672 self .insert_graph_receipt_txn ,
644673 room_id ,
@@ -649,8 +678,14 @@ async def insert_graph_receipt(
649678 )
650679
651680 def insert_graph_receipt_txn (
652- self , txn , room_id , receipt_type , user_id , event_ids , data
653- ):
681+ self ,
682+ txn : LoggingTransaction ,
683+ room_id : str ,
684+ receipt_type : str ,
685+ user_id : str ,
686+ event_ids : List [str ],
687+ data : JsonDict ,
688+ ) -> None :
654689 assert self ._can_write_to_receipts
655690
656691 txn .call_after (self .get_receipts_for_room .invalidate , (room_id , receipt_type ))
0 commit comments