From 3319099d2800f819f13b00a66d03ba3a12b30ec6 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Dec 2021 12:56:12 -0500 Subject: [PATCH 1/7] Add a constant for receipt types. --- synapse/api/constants.py | 4 ++++ synapse/handlers/sync.py | 4 ++-- synapse/push/push_tools.py | 3 ++- synapse/rest/client/notifications.py | 3 ++- synapse/rest/client/read_marker.py | 4 ++-- synapse/rest/client/receipts.py | 4 ++-- synapse/storage/databases/main/receipts.py | 7 ++++--- 7 files changed, 18 insertions(+), 11 deletions(-) diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f7d29b431936..52c083a20b9c 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -253,5 +253,9 @@ class GuestAccess: FORBIDDEN: Final = "forbidden" +class ReceiptTypes: + READ: Final = "m.read" + + class ReadReceiptEventFields: MSC2285_HIDDEN: Final = "org.matrix.msc2285.hidden" diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index f3039c3c3fb7..96f37e9f4204 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -28,7 +28,7 @@ import attr from prometheus_client import Counter -from synapse.api.constants import AccountDataTypes, EventTypes, Membership +from synapse.api.constants import AccountDataTypes, EventTypes, Membership, ReceiptTypes from synapse.api.filtering import FilterCollection from synapse.api.presence import UserPresenceState from synapse.api.room_versions import KNOWN_ROOM_VERSIONS @@ -1046,7 +1046,7 @@ async def unread_notifs_for_room_id( last_unread_event_id = await self.store.get_last_receipt_event_id_for_user( user_id=sync_config.user.to_string(), room_id=room_id, - receipt_type="m.read", + receipt_type=ReceiptTypes.READ, ) notifs = await self.store.get_unread_event_push_actions_by_room_for_user( diff --git a/synapse/push/push_tools.py b/synapse/push/push_tools.py index 9c85200c0fb4..da641aca477c 100644 --- a/synapse/push/push_tools.py +++ b/synapse/push/push_tools.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import Dict +from synapse.api.constants import ReceiptTypes from synapse.events import EventBase from synapse.push.presentable_names import calculate_room_name, name_from_member_event from synapse.storage import Storage @@ -23,7 +24,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) - invites = await store.get_invited_rooms_for_local_user(user_id) joins = await store.get_rooms_for_user(user_id) - my_receipts_by_room = await store.get_receipts_for_user(user_id, "m.read") + my_receipts_by_room = await store.get_receipts_for_user(user_id, ReceiptTypes.READ) badge = len(invites) diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py index d1d8a984c630..b12a332776e4 100644 --- a/synapse/rest/client/notifications.py +++ b/synapse/rest/client/notifications.py @@ -15,6 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple +from synapse.api.constants import ReceiptTypes from synapse.events.utils import format_event_for_client_v2_without_room_id from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_integer, parse_string @@ -54,7 +55,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: ) receipts_by_room = await self.store.get_receipts_for_user_with_orderings( - user_id, "m.read" + user_id, ReceiptTypes.READ ) notif_event_ids = [pa["event_id"] for pa in push_actions] diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index 43c04fac6fdb..adb4df1aa5ef 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -15,7 +15,7 @@ import logging from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http.server import HttpServer from synapse.http.servlet import RestServlet, parse_json_object_from_request @@ -62,7 +62,7 @@ async def on_POST( if read_event_id: await self.receipts_handler.received_client_receipt( room_id, - "m.read", + ReceiptTypes.READ, user_id=requester.user.to_string(), event_id=read_event_id, hidden=hidden, diff --git a/synapse/rest/client/receipts.py b/synapse/rest/client/receipts.py index 2b25b9aad6a3..b24ad2d1be13 100644 --- a/synapse/rest/client/receipts.py +++ b/synapse/rest/client/receipts.py @@ -16,7 +16,7 @@ import re from typing import TYPE_CHECKING, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.api.errors import Codes, SynapseError from synapse.http import get_request_user_agent from synapse.http.server import HttpServer @@ -53,7 +53,7 @@ async def on_POST( ) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: raise SynapseError(400, "Receipt type must be 'm.read'") # Do not allow older SchildiChat and Element Android clients (prior to Element/1.[012].x) to send an empty body. diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index c99f8aebdbdd..8df75f9c74b1 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -18,6 +18,7 @@ from twisted.internet import defer +from synapse.api.constants import ReceiptTypes from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -88,7 +89,7 @@ def get_max_receipt_stream_id(self): @cached() async def get_users_with_read_receipts_in_room(self, room_id): - receipts = await self.get_receipts_for_room(room_id, "m.read") + receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @cached(num_args=2) @@ -447,7 +448,7 @@ def get_all_updated_receipts_txn(txn): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str ): - if receipt_type != "m.read": + if receipt_type != ReceiptTypes.READ: return res = self.get_users_with_read_receipts_in_room.cache.get_immediate( @@ -550,7 +551,7 @@ def insert_linearized_receipt_txn( lock=False, ) - if receipt_type == "m.read" and stream_ordering is not None: + if receipt_type == ReceiptTypes.READ and stream_ordering is not None: self._remove_old_push_actions_before_txn( txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering ) From ca4adc681d558e518390c0b1a2dc109c6b41c1a5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Dec 2021 13:06:39 -0500 Subject: [PATCH 2/7] Add missing type hints to receipt store. --- synapse/storage/databases/main/receipts.py | 70 ++++++++++++++++------ 1 file changed, 53 insertions(+), 17 deletions(-) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 8df75f9c74b1..67889f9d9df2 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -14,7 +14,17 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Collection, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, +) from twisted.internet import defer @@ -79,7 +89,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): "ReceiptsRoomChangeCache", self.get_max_receipt_stream_id() ) - def get_max_receipt_stream_id(self): + def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream Returns: @@ -88,7 +98,7 @@ def get_max_receipt_stream_id(self): return self._receipts_id_gen.get_current_token() @cached() - async def get_users_with_read_receipts_in_room(self, room_id): + async def get_users_with_read_receipts_in_room(self, room_id: str) -> Set[str]: receipts = await self.get_receipts_for_room(room_id, ReceiptTypes.READ) return {r["user_id"] for r in receipts} @@ -120,7 +130,9 @@ async def get_last_receipt_event_id_for_user( ) @cached(num_args=2) - async def get_receipts_for_user(self, user_id, receipt_type): + async def get_receipts_for_user( + self, user_id: str, receipt_type: str + ) -> Dict[str, str]: rows = await self.db_pool.simple_select_list( table="receipts_linearized", keyvalues={"user_id": user_id, "receipt_type": receipt_type}, @@ -130,8 +142,10 @@ async def get_receipts_for_user(self, user_id, receipt_type): return {row["room_id"]: row["event_id"] for row in rows} - async def get_receipts_for_user_with_orderings(self, user_id, receipt_type): - def f(txn): + async def get_receipts_for_user_with_orderings( + self, user_id: str, receipt_type: str + ) -> JsonDict: + def f(txn) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -251,11 +265,13 @@ def f(txn): list_name="room_ids", num_args=3, ) - async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None): + async def _get_linearized_receipts_for_rooms( + self, room_ids: Collection[str], to_key: int, from_key: Optional[int] = None + ) -> Dict[str, List[JsonDict]]: if not room_ids: return {} - def f(txn): + def f(txn) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -447,7 +463,7 @@ def get_all_updated_receipts_txn(txn): def _invalidate_get_users_with_receipts_in_room( self, room_id: str, receipt_type: str, user_id: str - ): + ) -> None: if receipt_type != ReceiptTypes.READ: return @@ -462,7 +478,9 @@ def _invalidate_get_users_with_receipts_in_room( self.get_users_with_read_receipts_in_room.invalidate((room_id,)) - def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id): + def invalidate_caches_for_receipt( + self, room_id: str, receipt_type: str, user_id: str + ) -> None: self.get_receipts_for_user.invalidate((user_id, receipt_type)) self._get_linearized_receipts_for_room.invalidate((room_id,)) self.get_last_receipt_event_id_for_user.invalidate( @@ -483,8 +501,15 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): return super().process_replication_rows(stream_name, instance_name, token, rows) def insert_linearized_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_id, data, stream_id - ): + self, + txn, + room_id: str, + receipt_type: str, + user_id: str, + event_id: str, + data: JsonDict, + stream_id: int, + ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR Returns: int|None @@ -635,11 +660,16 @@ def graph_to_linear(txn): return stream_id, max_persisted_id async def insert_graph_receipt( - self, room_id, receipt_type, user_id, event_ids, data - ): + self, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts - return await self.db_pool.runInteraction( + await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, @@ -650,8 +680,14 @@ async def insert_graph_receipt( ) def insert_graph_receipt_txn( - self, txn, room_id, receipt_type, user_id, event_ids, data - ): + self, + txn, + room_id: str, + receipt_type: str, + user_id: str, + event_ids: List[str], + data: JsonDict, + ) -> None: assert self._can_write_to_receipts txn.call_after(self.get_receipts_for_room.invalidate, (room_id, receipt_type)) From b6e96c4424cb1fdef496304aa63590832d962d7b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Dec 2021 13:14:36 -0500 Subject: [PATCH 3/7] Add types to transactions. --- synapse/storage/databases/main/receipts.py | 24 ++++++++++++---------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 67889f9d9df2..40b5a5755d43 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -32,7 +32,7 @@ from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.replication.tcp.streams import ReceiptsStream from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause -from synapse.storage.database import DatabasePool +from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator from synapse.types import JsonDict @@ -145,7 +145,7 @@ async def get_receipts_for_user( async def get_receipts_for_user_with_orderings( self, user_id: str, receipt_type: str ) -> JsonDict: - def f(txn) -> List[Tuple[str, str, int, int]]: + def f(txn: LoggingTransaction) -> List[Tuple[str, str, int, int]]: sql = ( "SELECT rl.room_id, rl.event_id," " e.topological_ordering, e.stream_ordering" @@ -224,10 +224,10 @@ async def get_linearized_receipts_for_room( @cached(num_args=3, tree=True) async def _get_linearized_receipts_for_room( self, room_id: str, to_key: int, from_key: Optional[int] = None - ) -> List[dict]: + ) -> List[JsonDict]: """See get_linearized_receipts_for_room""" - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = ( "SELECT * FROM receipts_linearized WHERE" @@ -271,7 +271,7 @@ async def _get_linearized_receipts_for_rooms( if not room_ids: return {} - def f(txn) -> List[Dict[str, Any]]: + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -340,7 +340,7 @@ async def get_linearized_receipts_for_all_rooms( A dictionary of roomids to a list of receipts. """ - def f(txn): + def f(txn: LoggingTransaction) -> List[Dict[str, Any]]: if from_key: sql = """ SELECT * FROM receipts_linearized WHERE @@ -396,7 +396,7 @@ async def get_users_sent_receipts_between( if last_id == current_id: return defer.succeed([]) - def _get_users_sent_receipts_between_txn(txn): + def _get_users_sent_receipts_between_txn(txn: LoggingTransaction) -> List[str]: sql = """ SELECT DISTINCT user_id FROM receipts_linearized WHERE ? < stream_id AND stream_id <= ? @@ -436,7 +436,9 @@ async def get_all_updated_receipts( if last_id == current_id: return [], current_id, False - def get_all_updated_receipts_txn(txn): + def get_all_updated_receipts_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[int, list]], int, bool]: sql = """ SELECT stream_id, room_id, receipt_type, user_id, event_id, data FROM receipts_linearized @@ -502,7 +504,7 @@ def process_replication_rows(self, stream_name, instance_name, token, rows): def insert_linearized_receipt_txn( self, - txn, + txn: LoggingTransaction, room_id: str, receipt_type: str, user_id: str, @@ -606,7 +608,7 @@ async def insert_receipt( else: # we need to points in graph -> linearized form. # TODO: Make this better. - def graph_to_linear(txn): + def graph_to_linear(txn: LoggingTransaction) -> str: clause, args = make_in_list_sql_clause( self.database_engine, "event_id", event_ids ) @@ -681,7 +683,7 @@ async def insert_graph_receipt( def insert_graph_receipt_txn( self, - txn, + txn: LoggingTransaction, room_id: str, receipt_type: str, user_id: str, From 086adf55f9b3624853977eb470fb6b61faec32ce Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 7 Dec 2021 13:16:03 -0500 Subject: [PATCH 4/7] Newsfragment --- changelog.d/11531.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/11531.misc diff --git a/changelog.d/11531.misc b/changelog.d/11531.misc new file mode 100644 index 000000000000..ed6ef3bb3e56 --- /dev/null +++ b/changelog.d/11531.misc @@ -0,0 +1 @@ +Add a receipt types constant for `m.read`. From 2a876828b8e81ddd4041500a8f4b063ca9083ad9 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Dec 2021 10:29:00 -0500 Subject: [PATCH 5/7] Use constant in more places. --- synapse/handlers/receipts.py | 6 +++--- synapse/rest/client/read_marker.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py index 4911a1153519..5cb1ff749d92 100644 --- a/synapse/handlers/receipts.py +++ b/synapse/handlers/receipts.py @@ -14,7 +14,7 @@ import logging from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple -from synapse.api.constants import ReadReceiptEventFields +from synapse.api.constants import ReadReceiptEventFields, ReceiptTypes from synapse.appservice import ApplicationService from synapse.streams import EventSource from synapse.types import JsonDict, ReadReceipt, UserID, get_domain_from_id @@ -178,7 +178,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: for event_id in content.keys(): event_content = content.get(event_id, {}) - m_read = event_content.get("m.read", {}) + m_read = event_content.get(ReceiptTypes.READ, {}) # If m_read is missing copy over the original event_content as there is nothing to process here if not m_read: @@ -206,7 +206,7 @@ def filter_out_hidden(events: List[JsonDict], user_id: str) -> List[JsonDict]: # Set new users unless empty if len(new_users.keys()) > 0: - new_event["content"][event_id] = {"m.read": new_users} + new_event["content"][event_id] = {ReceiptTypes.READ: new_users} # Append new_event to visible_events unless empty if len(new_event["content"].keys()) > 0: diff --git a/synapse/rest/client/read_marker.py b/synapse/rest/client/read_marker.py index adb4df1aa5ef..f51be511d1f4 100644 --- a/synapse/rest/client/read_marker.py +++ b/synapse/rest/client/read_marker.py @@ -48,7 +48,7 @@ async def on_POST( await self.presence_handler.bump_presence_active_time(requester.user) body = parse_json_object_from_request(request) - read_event_id = body.get("m.read", None) + read_event_id = body.get(ReceiptTypes.READ, None) hidden = body.get(ReadReceiptEventFields.MSC2285_HIDDEN, False) if not isinstance(hidden, bool): From cd21bcae9a1bd8b80fe8a712d5286ef1d0ed6371 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Dec 2021 10:30:54 -0500 Subject: [PATCH 6/7] Remove some types from comments. --- synapse/storage/databases/main/receipts.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 40b5a5755d43..f5de43fcfd72 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -91,9 +91,6 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): def get_max_receipt_stream_id(self) -> int: """Get the current max stream ID for receipts stream - - Returns: - int """ return self._receipts_id_gen.get_current_token() @@ -514,7 +511,7 @@ def insert_linearized_receipt_txn( ) -> Optional[int]: """Inserts a read-receipt into the database if it's newer than the current RR - Returns: int|None + Returns: None if the RR is older than the current RR otherwise, the rx timestamp of the event that the RR corresponds to (or 0 if the event is unknown) From 764070f3735490a346ed7adc232af0549c1e929a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 8 Dec 2021 10:35:45 -0500 Subject: [PATCH 7/7] Lint --- synapse/storage/databases/main/receipts.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index f5de43fcfd72..9c5625c8bbb8 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -90,8 +90,7 @@ def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"): ) def get_max_receipt_stream_id(self) -> int: - """Get the current max stream ID for receipts stream - """ + """Get the current max stream ID for receipts stream""" return self._receipts_id_gen.get_current_token() @cached()