Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/16240.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Delete device messages asynchronously and in staged batches using the task scheduler.
48 changes: 48 additions & 0 deletions synapse/handlers/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@
)
from synapse.types import (
JsonDict,
JsonMapping,
ScheduledTask,
StrCollection,
StreamKeyType,
StreamToken,
TaskStatus,
UserID,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
Expand All @@ -62,6 +65,7 @@

logger = logging.getLogger(__name__)

DELETE_DEVICE_MSGS_TASK_NAME = "delete_device_messages"
MAX_DEVICE_DISPLAY_NAME_LEN = 100
DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000

Expand All @@ -78,6 +82,7 @@ def __init__(self, hs: "HomeServer"):
self._appservice_handler = hs.get_application_service_handler()
self._state_storage = hs.get_storage_controllers().state
self._auth_handler = hs.get_auth_handler()
self._event_sources = hs.get_event_sources()
self.server_name = hs.hostname
self._msc3852_enabled = hs.config.experimental.msc3852_enabled
self._query_appservices_for_keys = (
Expand Down Expand Up @@ -386,6 +391,7 @@ def __init__(self, hs: "HomeServer"):
self._account_data_handler = hs.get_account_data_handler()
self._storage_controllers = hs.get_storage_controllers()
self.db_pool = hs.get_datastores().main.db_pool
self._task_scheduler = hs.get_task_scheduler()

self.device_list_updater = DeviceListUpdater(hs, self)

Expand Down Expand Up @@ -419,6 +425,10 @@ def __init__(self, hs: "HomeServer"):
self._delete_stale_devices,
)

self._task_scheduler.register_action(
self._delete_device_messages, DELETE_DEVICE_MSGS_TASK_NAME
)

def _check_device_name_length(self, name: Optional[str]) -> None:
"""
Checks whether a device name is longer than the maximum allowed length.
Expand Down Expand Up @@ -530,6 +540,7 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
user_id: The user to delete devices from.
device_ids: The list of device IDs to delete
"""
to_device_stream_id = self._event_sources.get_current_token().to_device_key

try:
await self.store.delete_devices(user_id, device_ids)
Expand Down Expand Up @@ -559,12 +570,49 @@ async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
f"org.matrix.msc3890.local_notification_settings.{device_id}",
)

# Delete device messages asynchronously and in batches using the task scheduler
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=device_id,
params={
"user_id": user_id,
"device_id": device_id,
"up_to_stream_id": to_device_stream_id,
},
)

# Pushers are deleted after `delete_access_tokens_for_user` is called so that
# modules using `on_logged_out` hook can use them if needed.
await self.hs.get_pusherpool().remove_pushers_by_devices(user_id, device_ids)

await self.notify_device_update(user_id, device_ids)

DEVICE_MSGS_DELETE_BATCH_LIMIT = 100

async def _delete_device_messages(
    self,
    task: ScheduledTask,
) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
    """Scheduler action that deletes a device's to-device messages in batches.

    Each invocation removes at most `DEVICE_MSGS_DELETE_BATCH_LIMIT` messages
    and reports whether the task is finished or should run again.

    Args:
        task: the scheduled task; its params must contain `user_id`,
            `device_id` and `up_to_stream_id`.

    Returns:
        A `(status, result, error)` tuple as expected by the task scheduler.
    """
    assert task.params is not None
    params = task.params

    batch_limit = DeviceHandler.DEVICE_MSGS_DELETE_BATCH_LIMIT
    deleted = await self.store.delete_messages_for_device(
        user_id=params["user_id"],
        device_id=params["device_id"],
        up_to_stream_id=params["up_to_stream_id"],
        limit=batch_limit,
    )

    if deleted >= batch_limit:
        # A full batch was removed, so more messages probably remain: keep the
        # task ACTIVE so the scheduler runs it again on a subsequent loop run
        # (probably the next one, if not too many tasks are running).
        return TaskStatus.ACTIVE, None, None

    # Fewer rows than the batch limit were deleted: nothing left to do.
    return TaskStatus.COMPLETE, None, None

async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
"""Update the given device

Expand Down
4 changes: 1 addition & 3 deletions synapse/handlers/presence.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class BasePresenceHandler(abc.ABC):
writer"""

def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
Expand Down Expand Up @@ -426,8 +427,6 @@ def __exit__(
class WorkerPresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs

self._presence_writer_instance = hs.config.worker.writers.presence[0]

# Route presence EDUs to the right worker
Expand Down Expand Up @@ -691,7 +690,6 @@ async def bump_presence_active_time(
class PresenceHandler(BasePresenceHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.wheel_timer: WheelTimer[str] = WheelTimer()
self.notifier = hs.get_notifier()

Expand Down
16 changes: 13 additions & 3 deletions synapse/handlers/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from synapse.api.presence import UserPresenceState
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.handlers.device import DELETE_DEVICE_MSGS_TASK_NAME
from synapse.handlers.relations import BundledAggregations
from synapse.logging import issue9533_logger
from synapse.logging.context import current_context
Expand Down Expand Up @@ -268,6 +269,7 @@ def __init__(self, hs: "HomeServer"):
self._storage_controllers = hs.get_storage_controllers()
self._state_storage_controller = self._storage_controllers.state
self._device_handler = hs.get_device_handler()
self._task_scheduler = hs.get_task_scheduler()

self.should_calculate_push_rules = hs.config.push.enable_push

Expand Down Expand Up @@ -360,11 +362,19 @@ async def _wait_for_sync_for_user(
# (since we now know that the device has received them)
if since_token is not None:
since_stream_id = since_token.to_device_key
deleted = await self.store.delete_messages_for_device(
sync_config.user.to_string(), sync_config.device_id, since_stream_id
# Delete device messages asynchronously and in batches using the task scheduler
await self._task_scheduler.schedule_task(
DELETE_DEVICE_MSGS_TASK_NAME,
resource_id=sync_config.device_id,
params={
"user_id": sync_config.user.to_string(),
"device_id": sync_config.device_id,
"up_to_stream_id": since_stream_id,
},
)
logger.debug(
"Deleted %d to-device messages up to %d", deleted, since_stream_id
"Deletion of to-device messages up to %d scheduled",
since_stream_id,
)

if timeout == 0 or since_token is None or full_state:
Expand Down
26 changes: 20 additions & 6 deletions synapse/storage/databases/main/deviceinbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,13 +445,18 @@ def get_device_messages_txn(

@trace
async def delete_messages_for_device(
self, user_id: str, device_id: Optional[str], up_to_stream_id: int
self,
user_id: str,
device_id: Optional[str],
up_to_stream_id: int,
limit: int,
) -> int:
"""
Args:
user_id: The recipient user_id.
device_id: The recipient device_id.
up_to_stream_id: Where to delete messages up to.
limit: maximum number of messages to delete

Returns:
The number of messages deleted.
Expand All @@ -472,12 +477,16 @@ async def delete_messages_for_device(
log_kv({"message": "No changes in cache since last check"})
return 0

ROW_ID_NAME = self.database_engine.row_id_name

def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
sql = (
"DELETE FROM device_inbox"
" WHERE user_id = ? AND device_id = ?"
" AND stream_id <= ?"
)
sql = f"""
DELETE FROM device_inbox WHERE {ROW_ID_NAME} IN (
SELECT {ROW_ID_NAME} FROM device_inbox
WHERE user_id = ? AND device_id = ? AND stream_id <= ?
LIMIT {limit}
)
"""
txn.execute(sql, (user_id, device_id, up_to_stream_id))
return txn.rowcount

Expand All @@ -487,6 +496,11 @@ def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:

log_kv({"message": f"deleted {count} messages for device", "count": count})

# In this case we don't know if we hit the limit or the delete is complete
# so let's not update the cache.
if count == limit:
return count

# Update the cache, ensuring that we only ever increase the value
updated_last_deleted_stream_id = self._last_device_delete_cache.get(
(user_id, device_id), 0
Expand Down
8 changes: 0 additions & 8 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1766,14 +1766,6 @@ def _delete_devices_txn(txn: LoggingTransaction) -> None:
keyvalues={"user_id": user_id, "hidden": False},
)

self.db_pool.simple_delete_many_txn(
txn,
table="device_inbox",
column="device_id",
values=device_ids,
keyvalues={"user_id": user_id},
)

self.db_pool.simple_delete_many_txn(
txn,
table="device_auth_providers",
Expand Down
6 changes: 1 addition & 5 deletions synapse/storage/databases/main/receipts.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,11 +939,7 @@ async def _background_receipts_linearized_unique_index(
receipts."""

def _remote_duplicate_receipts_txn(txn: LoggingTransaction) -> None:
if isinstance(self.database_engine, PostgresEngine):
ROW_ID_NAME = "ctid"
else:
ROW_ID_NAME = "rowid"

ROW_ID_NAME = self.database_engine.row_id_name
# Identify any duplicate receipts arising from
# https://github.com/matrix-org/synapse/issues/14406.
# The following query takes less than a minute on matrix.org.
Expand Down
6 changes: 6 additions & 0 deletions synapse/storage/engines/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'"""
...

@property
@abc.abstractmethod
def row_id_name(self) -> str:
    """Gets the literal name representing a row id for this engine.

    E.g. ``ctid`` on PostgreSQL or ``rowid`` on SQLite; the name can be
    interpolated into SQL to uniquely address individual rows of a table.
    """
    ...

@abc.abstractmethod
def in_transaction(self, conn: ConnectionType) -> bool:
"""Whether the connection is currently in a transaction."""
Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/engines/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,10 @@ def server_version(self) -> str:
else:
return "%i.%i.%i" % (numver / 10000, (numver % 10000) / 100, numver % 100)

@property
def row_id_name(self) -> str:
    """Gets the literal name representing a row id for this engine.

    PostgreSQL exposes a row's physical location as the ``ctid`` system
    column; there is no SQLite-style implicit ``rowid`` column.
    """
    return "ctid"

def in_transaction(self, conn: psycopg2.extensions.connection) -> bool:
return conn.status != psycopg2.extensions.STATUS_READY

Expand Down
4 changes: 4 additions & 0 deletions synapse/storage/engines/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,10 @@ def server_version(self) -> str:
"""Gets a string giving the server version. For example: '3.22.0'."""
return "%i.%i.%i" % sqlite3.sqlite_version_info

@property
def row_id_name(self) -> str:
    """Gets the literal name representing a row id for this engine.

    SQLite tables (unless declared ``WITHOUT ROWID``) carry an implicit
    ``rowid`` column uniquely identifying each row.
    """
    return "rowid"

def in_transaction(self, conn: sqlite3.Connection) -> bool:
return conn.in_transaction

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


from synapse.storage.database import LoggingTransaction
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.storage.engines import BaseDatabaseEngine
from synapse.storage.prepare_database import get_statements

FIX_INDEXES = """
Expand All @@ -37,7 +37,7 @@


def run_create(cur: LoggingTransaction, database_engine: BaseDatabaseEngine) -> None:
rowid = "ctid" if isinstance(database_engine, PostgresEngine) else "rowid"
rowid = database_engine.row_id_name

# remove duplicates from group_users & group_invites tables
cur.execute(
Expand Down
17 changes: 7 additions & 10 deletions synapse/util/task_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class TaskScheduler:
LAST_UPDATE_BEFORE_WARNING_MS = 24 * 60 * 60 * 1000 # 24hrs

def __init__(self, hs: "HomeServer"):
self._hs = hs
self._store = hs.get_datastores().main
self._clock = hs.get_clock()
self._running_tasks: Set[str] = set()
Expand All @@ -97,8 +98,6 @@ def __init__(self, hs: "HomeServer"):
"handle_scheduled_tasks",
self._handle_scheduled_tasks,
)
else:
self.replication_client = hs.get_replication_command_handler()

def register_action(
self,
Expand Down Expand Up @@ -133,7 +132,7 @@ async def schedule_task(
params: Optional[JsonMapping] = None,
) -> str:
"""Schedule a new potentially resumable task. A function matching the specified
`action` should have been previously registered with `register_action`.
`action` should have been registered with `register_action` before the task is run.

Args:
action: the name of a previously registered action
Expand All @@ -149,11 +148,6 @@ async def schedule_task(
Returns:
The id of the scheduled task
"""
if action not in self._actions:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This check is moved inside _launch_task instead, since actions only need to be registered on the background worker and may not be registered on other ones.

raise Exception(
f"No function associated with action {action} of the scheduled task"
)

status = TaskStatus.SCHEDULED
if timestamp is None or timestamp < self._clock.time_msec():
timestamp = self._clock.time_msec()
Expand All @@ -175,7 +169,7 @@ async def schedule_task(
if self._run_background_tasks:
await self._launch_task(task)
else:
self.replication_client.send_new_active_task(task.id)
self._hs.get_replication_command_handler().send_new_active_task(task.id)

return task.id

Expand Down Expand Up @@ -315,7 +309,10 @@ async def _launch_task(self, task: ScheduledTask) -> None:
"""
assert self._run_background_tasks

assert task.action in self._actions
if task.action not in self._actions:
raise Exception(
f"No function associated with action {task.action} of the scheduled task {task.id}"
)
function = self._actions[task.action]

async def wrapper() -> None:
Expand Down
Loading