This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Claim local one-time-keys in bulk #16565

Merged: 15 commits, Oct 30, 2023
Bulk claim OTKs
David Robertson committed Oct 28, 2023
commit da695381a57f97f61466a05b7e9ca17fb11ae0bc
110 changes: 62 additions & 48 deletions synapse/storage/databases/main/end_to_end_keys.py
@@ -24,6 +24,7 @@
     Mapping,
     Optional,
     Sequence,
+    Set,
     Tuple,
     Union,
     cast,
@@ -1133,25 +1134,31 @@ async def claim_e2e_one_time_keys(
         if self.database_engine.supports_returning:
             # If we support RETURNING clause we can use a single query that
             # allows us to use autocommit mode.
+            unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {}
             for user_id, device_id, algorithm, count in query_list:
-                claim_rows = await self.db_pool.runInteraction(
-                    "claim_e2e_one_time_keys",
-                    self._claim_e2e_one_time_key_returning,
-                    user_id,
-                    device_id,
-                    algorithm,
-                    count,
-                    db_autocommit=True,
-                )
-                if claim_rows:
-                    device_results = results.setdefault(user_id, {}).setdefault(
-                        device_id, {}
-                    )
-                    for claim_row in claim_rows:
-                        device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
-                # Did we get enough OTKs?
-                count -= len(claim_rows)
-                if count:
-                    missing.append((user_id, device_id, algorithm, count))
+                unfulfilled_claim_counts[user_id, device_id, algorithm] = count
+
+            bulk_claims = await self.db_pool.runInteraction(
+                "claim_e2e_one_time_keys",
+                self._claim_e2e_one_time_keys_returning,
+                query_list,
+                db_autocommit=True,
+            )
+
+            for user_id, device_id, algorithm, key_id, key_json in bulk_claims:
+                device_results = results.setdefault(user_id, {}).setdefault(
+                    device_id, {}
+                )
+                device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
+                unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1
+
+            # Did we get enough OTKs?
+            for (
+                user_id,
+                device_id,
+                algorithm,
+            ), count in unfulfilled_claim_counts.items():
+                if count > 0:
+                    missing.append((user_id, device_id, algorithm, count))
         else:
             for user_id, device_id, algorithm, count in query_list:
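A minimal sketch of the unfulfilled-count bookkeeping introduced above, runnable on its own. Toy data only; tally_unfulfilled is a hypothetical stand-in for the inline logic in claim_e2e_one_time_keys, not a Synapse API:

from typing import Dict, List, Tuple

def tally_unfulfilled(
    query_list: List[Tuple[str, str, str, int]],
    bulk_claims: List[Tuple[str, str, str, str, str]],
) -> List[Tuple[str, str, str, int]]:
    # Start by assuming no claim was fulfilled...
    counts: Dict[Tuple[str, str, str], int] = {
        (u, d, a): n for u, d, a, n in query_list
    }
    # ...then decrement once per key the bulk DELETE actually returned.
    for u, d, a, _key_id, _key_json in bulk_claims:
        counts[(u, d, a)] -= 1
    # Anything still positive is reported back as missing.
    return [(u, d, a, n) for (u, d, a), n in counts.items() if n > 0]

# Two keys requested, one claimed: one claim is left unfulfilled.
assert tally_unfulfilled(
    [("@alice:test", "DEV1", "signed_curve25519", 2)],
    [("@alice:test", "DEV1", "signed_curve25519", "AAAA", "{}")],
) == [("@alice:test", "DEV1", "signed_curve25519", 1)]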
@@ -1276,46 +1283,53 @@ def _claim_e2e_one_time_key_simple(
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]

@trace
def _claim_e2e_one_time_key_returning(
def _claim_e2e_one_time_keys_returning(
self,
txn: LoggingTransaction,
user_id: str,
device_id: str,
algorithm: str,
count: int,
) -> List[Tuple[str, str]]:
"""Claim OTK for device for DBs that support RETURNING.
query_list: Iterable[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
"""Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING.

Args:
query_list: Collection of tuples (user_id, device_id, algorithm, count)
as passed to claim_e2e_one_time_keys.

Returns:
A tuple of key name (algorithm + key ID) and key JSON, if an
OTK was found.
A list of tuples (user_id, device_id, algorithm, key_id, key_json)
for each OTK claimed.
"""

# We can use RETURNING to do the fetch and DELETE in once step.
sql = """
DELETE FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
AND key_id IN (
SELECT key_id FROM e2e_one_time_keys_json
WHERE user_id = ? AND device_id = ? AND algorithm = ?
LIMIT ?
)
RETURNING key_id, key_json
"""

txn.execute(
sql,
(user_id, device_id, algorithm, user_id, device_id, algorithm, count),
WITH claims(user_id, device_id, algorithm, claim_count) AS (
VALUES ?
), ranked_keys AS (
SELECT
user_id, device_id, algorithm, key_id, claim_count,
ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
FROM e2e_one_time_keys_json
JOIN claims USING (user_id, device_id, algorithm)
)
DELETE FROM e2e_one_time_keys_json k
WHERE (user_id, device_id, algorithm, key_id) IN (
SELECT user_id, device_id, algorithm, key_id
FROM ranked_keys
WHERE r <= claim_count
)
RETURNING user_id, device_id, algorithm, key_id, key_json;
"""
otk_rows = cast(
List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list)
)
otk_rows = list(txn)
if not otk_rows:
return []

self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
seen_user_device: Set[Tuple[str, str]] = set()
for user_id, device_id, _, _, _ in otk_rows:
if (user_id, device_id) in seen_user_device:
continue
seen_user_device.add((user_id, device_id))
self._invalidate_cache_and_stream(
txn, self.count_e2e_one_time_keys, (user_id, device_id)
)
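For readers wondering how "VALUES ?" takes a whole list: as I understand the wrappers, on PostgreSQL Synapse rewrites ? placeholders to %s, and txn.execute_values delegates to psycopg2's execute_values helper, which splices one row literal per tuple into that single placeholder. A rough sketch at the plain-psycopg2 level, not Synapse's wrappers; the connection string is a placeholder:

import psycopg2
from psycopg2.extras import execute_values

# "dbname=synapse" is a placeholder; point this at any reachable Postgres.
conn = psycopg2.connect("dbname=synapse")
with conn, conn.cursor() as cur:
    # execute_values expands the single %s into
    # VALUES ('@alice:test','DEV1','signed_curve25519',2), (...), ...
    rows = execute_values(
        cur,
        "WITH claims(user_id, device_id, algorithm, claim_count) AS (VALUES %s) "
        "SELECT * FROM claims",
        [("@alice:test", "DEV1", "signed_curve25519", 2)],
        fetch=True,
    )
    print(rows)  # [('@alice:test', 'DEV1', 'signed_curve25519', 2)]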
Comment on lines +1320 to +1327

David Robertson (Contributor, Author):

    The old code never had this deduplication. If the same (user, device) showed up twice (either with multiple algorithms or claiming multiple keys) we'd send out more than one invalidation for that (user, device). I don't know how much this matters in practice.

Reviewer (Member):

    > I don't know how much this matters in practice.

    I suspect it is just inefficient, but doesn't matter too much?
return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows]
return otk_rows


class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
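To see the per-partition claim trick in isolation: the same shape of query runs against an in-memory SQLite database (DELETE ... RETURNING needs SQLite >= 3.35; toy schema and data, a literal VALUES row standing in for the execute_values expansion, and a plain column list in PARTITION BY in place of the Postgres row syntax):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript(
    """
    CREATE TABLE e2e_one_time_keys_json
        (user_id TEXT, device_id TEXT, algorithm TEXT, key_id TEXT, key_json TEXT);
    INSERT INTO e2e_one_time_keys_json VALUES
        ('@alice:test', 'DEV1', 'signed_curve25519', 'AAAA', '{}'),
        ('@alice:test', 'DEV1', 'signed_curve25519', 'AAAB', '{}'),
        ('@alice:test', 'DEV1', 'signed_curve25519', 'AAAC', '{}');
    """
)

# Claim at most 2 keys for the single (user, device, algorithm) partition.
rows = conn.execute(
    """
    WITH claims(user_id, device_id, algorithm, claim_count) AS (
        VALUES ('@alice:test', 'DEV1', 'signed_curve25519', 2)
    ), ranked_keys AS (
        SELECT
            user_id, device_id, algorithm, key_id, claim_count,
            ROW_NUMBER() OVER (PARTITION BY user_id, device_id, algorithm) AS r
        FROM e2e_one_time_keys_json
            JOIN claims USING (user_id, device_id, algorithm)
    )
    DELETE FROM e2e_one_time_keys_json
    WHERE (user_id, device_id, algorithm, key_id) IN (
        SELECT user_id, device_id, algorithm, key_id
        FROM ranked_keys
        WHERE r <= claim_count
    )
    RETURNING user_id, device_id, algorithm, key_id, key_json
    """
).fetchall()

assert len(rows) == 2  # two keys claimed...
assert conn.execute(
    "SELECT COUNT(*) FROM e2e_one_time_keys_json"
).fetchone() == (1,)   # ...and one left behind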