Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Cache requests for user's devices from federation #15675

Merged
Merged 5 commits on
Jun 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15675.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Cache requests for user's devices over federation.
4 changes: 4 additions & 0 deletions synapse/storage/databases/main/devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,6 +1941,10 @@ def _add_device_change_to_stream_txn(
user_id,
stream_ids[-1],
)
txn.call_after(
self._get_e2e_device_keys_for_federation_query_inner.invalidate,
(user_id,),
)

min_stream_id = stream_ids[0]

Expand Down
67 changes: 65 additions & 2 deletions synapse/storage/databases/main/end_to_end_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import abc
from typing import (
TYPE_CHECKING,
Any,
Collection,
Dict,
Iterable,
Expand All @@ -39,6 +40,7 @@
TransactionUnusedFallbackKeys,
)
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.replication.tcp.streams._base import DeviceListsStream
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
DatabasePool,
Expand Down Expand Up @@ -104,6 +106,23 @@ def __init__(
self.hs.config.federation.allow_device_name_lookup_over_federation
)

def process_replication_rows(
    self,
    stream_name: str,
    instance_name: str,
    token: int,
    rows: Iterable[Any],
) -> None:
    """Invalidate the federation device-keys cache for users whose device
    lists changed, as notified over the replication TCP stream.

    Args:
        stream_name: name of the replication stream the rows came from.
        instance_name: name of the worker instance that wrote the rows.
        token: the stream position these rows advance to.
        rows: the stream rows; for the device-lists stream each row is a
            `DeviceListsStream.DeviceListsStreamRow`.
    """
    if stream_name == DeviceListsStream.NAME:
        for row in rows:
            assert isinstance(row, DeviceListsStream.DeviceListsStreamRow)
            # Only entities that look like user IDs (leading "@") can have
            # an entry in the per-user cache; other entities are skipped.
            if row.entity.startswith("@"):
                self._get_e2e_device_keys_for_federation_query_inner.invalidate(
                    (row.entity,)
                )

    # Invalidate before delegating, so any superclass processing observes
    # the cache already cleared for the affected users.
    super().process_replication_rows(stream_name, instance_name, token, rows)

async def get_e2e_device_keys_for_federation_query(
self, user_id: str
) -> Tuple[int, List[JsonDict]]:
Expand All @@ -114,6 +133,50 @@ async def get_e2e_device_keys_for_federation_query(
"""
now_stream_id = self.get_device_stream_token()

# We need to be careful with the caching here, as we need to always
# return *all* persisted devices, however there may be a lag between a
# new device being persisted and the cache being invalidated.
cached_results = (
self._get_e2e_device_keys_for_federation_query_inner.cache.get_immediate(
user_id, None
)
)
if cached_results is not None:
# Check that there have been no new devices added by another worker
# after the cache. This should be quick as there should be few rows
# with a higher stream ordering.
#
# Note that we invalidate based on the device stream, so we only
# have to check for potential invalidations after the
# `now_stream_id`.
sql = """
SELECT user_id FROM device_lists_stream
WHERE stream_id >= ? AND user_id = ?
"""
rows = await self.db_pool.execute(
"get_e2e_device_keys_for_federation_query_check",
None,
sql,
now_stream_id,
user_id,
)
if not rows:
# No new rows, so cache is still valid.
return now_stream_id, cached_results

# There has, so let's invalidate the cache and run the query.
self._get_e2e_device_keys_for_federation_query_inner.invalidate((user_id,))

results = await self._get_e2e_device_keys_for_federation_query_inner(user_id)

return now_stream_id, results

@cached(iterable=True)
async def _get_e2e_device_keys_for_federation_query_inner(
self, user_id: str
) -> List[JsonDict]:
"""Get all devices (with any device keys) for a user"""

devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])

if devices:
Expand All @@ -134,9 +197,9 @@ async def get_e2e_device_keys_for_federation_query(

results.append(result)

return now_stream_id, results
return results

return now_stream_id, []
return []

@trace
@cancellable
Expand Down