From 651e5a23b5dca735ce67673c889cc231340cdd63 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 Oct 2023 18:53:06 +0100 Subject: [PATCH 01/15] Duplicate the two code paths --- .../storage/databases/main/end_to_end_keys.py | 85 +++++++++++++------ 1 file changed, 58 insertions(+), 27 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index f70f95eebaa5..efba5f199e3e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1216,35 +1216,66 @@ def _claim_e2e_one_time_key_returning( results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} missing: List[Tuple[str, str, str, int]] = [] - for user_id, device_id, algorithm, count in query_list: - if self.database_engine.supports_returning: - # If we support RETURNING clause we can use a single query that - # allows us to use autocommit mode. - _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning - db_autocommit = True - else: - _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple - db_autocommit = False + if self.database_engine.supports_returning: + for user_id, device_id, algorithm, count in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False - claim_rows = await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", - _claim_e2e_one_time_key, - user_id, - device_id, - algorithm, - count, - db_autocommit=db_autocommit, - ) - if claim_rows: - device_results = results.setdefault(user_id, {}).setdefault( - device_id, {} + claim_rows = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + count, + db_autocommit=db_autocommit, ) - for claim_row in claim_rows: - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) - # Did we get enough OTKs? - count -= len(claim_rows) - if count: - missing.append((user_id, device_id, algorithm, count)) + if claim_rows: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} + ) + for claim_row in claim_rows: + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # Did we get enough OTKs? + count -= len(claim_rows) + if count: + missing.append((user_id, device_id, algorithm, count)) + else: + for user_id, device_id, algorithm, count in query_list: + if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. + _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning + db_autocommit = True + else: + _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple + db_autocommit = False + + claim_rows = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + _claim_e2e_one_time_key, + user_id, + device_id, + algorithm, + count, + db_autocommit=db_autocommit, + ) + if claim_rows: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} + ) + for claim_row in claim_rows: + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + # Did we get enough OTKs? 
+ count -= len(claim_rows) + if count: + missing.append((user_id, device_id, algorithm, count)) return results, missing From cabe11352a9168f6c2c178ef09588af5074a3201 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 Oct 2023 18:54:43 +0100 Subject: [PATCH 02/15] Simplify --- .../storage/databases/main/end_to_end_keys.py | 28 ++++--------------- 1 file changed, 6 insertions(+), 22 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index efba5f199e3e..0fae022bb87f 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1217,24 +1217,17 @@ def _claim_e2e_one_time_key_returning( results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} missing: List[Tuple[str, str, str, int]] = [] if self.database_engine.supports_returning: + # If we support RETURNING clause we can use a single query that + # allows us to use autocommit mode. for user_id, device_id, algorithm, count in query_list: - if self.database_engine.supports_returning: - # If we support RETURNING clause we can use a single query that - # allows us to use autocommit mode. - _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning - db_autocommit = True - else: - _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple - db_autocommit = False - claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", - _claim_e2e_one_time_key, + _claim_e2e_one_time_key_returning, user_id, device_id, algorithm, count, - db_autocommit=db_autocommit, + db_autocommit=True, ) if claim_rows: device_results = results.setdefault(user_id, {}).setdefault( @@ -1248,23 +1241,14 @@ def _claim_e2e_one_time_key_returning( missing.append((user_id, device_id, algorithm, count)) else: for user_id, device_id, algorithm, count in query_list: - if self.database_engine.supports_returning: - # If we support RETURNING clause we can use a single query that - # allows us to use autocommit mode. - _claim_e2e_one_time_key = _claim_e2e_one_time_key_returning - db_autocommit = True - else: - _claim_e2e_one_time_key = _claim_e2e_one_time_key_simple - db_autocommit = False - claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", - _claim_e2e_one_time_key, + _claim_e2e_one_time_key_simple, user_id, device_id, algorithm, count, - db_autocommit=db_autocommit, + db_autocommit=False, ) if claim_rows: device_results = results.setdefault(user_id, {}).setdefault( From bcb2ba58c5d226b2ac958b14527025e33b9b01be Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 Oct 2023 18:58:15 +0100 Subject: [PATCH 03/15] Pull out the helpers to methods Otherwise symbols like `count` start to clash. Also my brain hurts. --- .../storage/databases/main/end_to_end_keys.py | 178 +++++++++--------- 1 file changed, 88 insertions(+), 90 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 0fae022bb87f..97d94d0b67cb 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1126,94 +1126,6 @@ async def claim_e2e_one_time_keys( A copy of the input which has not been fulfilled. """ - @trace - def _claim_e2e_one_time_key_simple( - txn: LoggingTransaction, - user_id: str, - device_id: str, - algorithm: str, - count: int, - ) -> List[Tuple[str, str]]: - """Claim OTK for device for DBs that don't support RETURNING. 
- - Returns: - A tuple of key name (algorithm + key ID) and key JSON, if an - OTK was found. - """ - - sql = """ - SELECT key_id, key_json FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT ? - """ - - txn.execute(sql, (user_id, device_id, algorithm, count)) - otk_rows = list(txn) - if not otk_rows: - return [] - - self.db_pool.simple_delete_many_txn( - txn, - table="e2e_one_time_keys_json", - column="key_id", - values=[otk_row[0] for otk_row in otk_rows], - keyvalues={ - "user_id": user_id, - "device_id": device_id, - "algorithm": algorithm, - }, - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return [ - (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows - ] - - @trace - def _claim_e2e_one_time_key_returning( - txn: LoggingTransaction, - user_id: str, - device_id: str, - algorithm: str, - count: int, - ) -> List[Tuple[str, str]]: - """Claim OTK for device for DBs that support RETURNING. - - Returns: - A tuple of key name (algorithm + key ID) and key JSON, if an - OTK was found. - """ - - # We can use RETURNING to do the fetch and DELETE in once step. - sql = """ - DELETE FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - AND key_id IN ( - SELECT key_id FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT ? - ) - RETURNING key_id, key_json - """ - - txn.execute( - sql, - (user_id, device_id, algorithm, user_id, device_id, algorithm, count), - ) - otk_rows = list(txn) - if not otk_rows: - return [] - - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - - return [ - (f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows - ] - results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} missing: List[Tuple[str, str, str, int]] = [] if self.database_engine.supports_returning: @@ -1222,7 +1134,7 @@ def _claim_e2e_one_time_key_returning( for user_id, device_id, algorithm, count in query_list: claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", - _claim_e2e_one_time_key_returning, + self._claim_e2e_one_time_key_returning, user_id, device_id, algorithm, @@ -1243,7 +1155,7 @@ def _claim_e2e_one_time_key_returning( for user_id, device_id, algorithm, count in query_list: claim_rows = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", - _claim_e2e_one_time_key_simple, + self._claim_e2e_one_time_key_simple, user_id, device_id, algorithm, @@ -1317,6 +1229,92 @@ async def claim_e2e_fallback_keys( return results + @trace + def _claim_e2e_one_time_key_simple( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: + """Claim OTK for device for DBs that don't support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + sql = """ + SELECT key_id, key_json FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT ? 
+ """ + + txn.execute(sql, (user_id, device_id, algorithm, count)) + otk_rows = list(txn) + if not otk_rows: + return [] + + self.db_pool.simple_delete_many_txn( + txn, + table="e2e_one_time_keys_json", + column="key_id", + values=[otk_row[0] for otk_row in otk_rows], + keyvalues={ + "user_id": user_id, + "device_id": device_id, + "algorithm": algorithm, + }, + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] + + @trace + def _claim_e2e_one_time_key_returning( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + algorithm: str, + count: int, + ) -> List[Tuple[str, str]]: + """Claim OTK for device for DBs that support RETURNING. + + Returns: + A tuple of key name (algorithm + key ID) and key JSON, if an + OTK was found. + """ + + # We can use RETURNING to do the fetch and DELETE in once step. + sql = """ + DELETE FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + AND key_id IN ( + SELECT key_id FROM e2e_one_time_keys_json + WHERE user_id = ? AND device_id = ? AND algorithm = ? + LIMIT ? + ) + RETURNING key_id, key_json + """ + + txn.execute( + sql, + (user_id, device_id, algorithm, user_id, device_id, algorithm, count), + ) + otk_rows = list(txn) + if not otk_rows: + return [] + + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) + + return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def __init__( From 49fa421bbb59d81f144c148438963af8082b6dbf Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 Oct 2023 23:45:38 +0100 Subject: [PATCH 04/15] Docstring tweak --- synapse/storage/databases/main/end_to_end_keys.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 97d94d0b67cb..c6ff5da10f38 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1120,10 +1120,12 @@ async def claim_e2e_one_time_keys( query_list: An iterable of tuples of (user ID, device ID, algorithm). Returns: - A tuple pf: + A tuple (results, missing) of: A map of user ID -> a map device ID -> a map of key ID -> JSON. - A copy of the input which has not been fulfilled. + A copy of the input which has not been fulfilled. The returned counts + may be less than the input counts. In this case, the returned counts + are the number of claims that were not fulfilled. """ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} From e30ae686d587a7e949ab4b30eee78f3808b49d65 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 Oct 2023 23:45:58 +0100 Subject: [PATCH 05/15] Require query_list: Collection I'm gonna iterate over it multiple times. --- synapse/storage/databases/main/end_to_end_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index c6ff5da10f38..49a91b0ac0e4 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1110,7 +1110,7 @@ def get_device_stream_token(self) -> int: ... 
async def claim_e2e_one_time_keys( - self, query_list: Iterable[Tuple[str, str, str, int]] + self, query_list: Collection[Tuple[str, str, str, int]] ) -> Tuple[ Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str, int]] ]: From da695381a57f97f61466a05b7e9ca17fb11ae0bc Mon Sep 17 00:00:00 2001 From: David Robertson Date: Fri, 27 Oct 2023 23:31:57 +0100 Subject: [PATCH 06/15] Bulk claim OTKs --- .../storage/databases/main/end_to_end_keys.py | 110 ++++++++++-------- 1 file changed, 62 insertions(+), 48 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 49a91b0ac0e4..865640c95ec0 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -24,6 +24,7 @@ Mapping, Optional, Sequence, + Set, Tuple, Union, cast, @@ -1133,25 +1134,31 @@ async def claim_e2e_one_time_keys( if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that # allows us to use autocommit mode. + unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {} for user_id, device_id, algorithm, count in query_list: - claim_rows = await self.db_pool.runInteraction( - "claim_e2e_one_time_keys", - self._claim_e2e_one_time_key_returning, - user_id, - device_id, - algorithm, - count, - db_autocommit=True, + unfulfilled_claim_counts[user_id, device_id, algorithm] = count + + bulk_claims = await self.db_pool.runInteraction( + "claim_e2e_one_time_keys", + self._claim_e2e_one_time_keys_returning, + query_list, + db_autocommit=True, + ) + + for user_id, device_id, algorithm, key_id, key_json in bulk_claims: + device_results = results.setdefault(user_id, {}).setdefault( + device_id, {} ) - if claim_rows: - device_results = results.setdefault(user_id, {}).setdefault( - device_id, {} - ) - for claim_row in claim_rows: - device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) - # Did we get enough OTKs? - count -= len(claim_rows) - if count: + device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) + unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1 + + # Did we get enough OTKs? + for ( + user_id, + device_id, + algorithm, + ), count in unfulfilled_claim_counts.items(): + if count > 0: missing.append((user_id, device_id, algorithm, count)) else: for user_id, device_id, algorithm, count in query_list: @@ -1276,46 +1283,53 @@ def _claim_e2e_one_time_key_simple( return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] @trace - def _claim_e2e_one_time_key_returning( + def _claim_e2e_one_time_keys_returning( self, txn: LoggingTransaction, - user_id: str, - device_id: str, - algorithm: str, - count: int, - ) -> List[Tuple[str, str]]: - """Claim OTK for device for DBs that support RETURNING. + query_list: Iterable[Tuple[str, str, str, int]], + ) -> List[Tuple[str, str, str, str, str]]: + """Bulk claim OTKs, for DBs that support DELETE FROM... RETURNING. + + Args: + query_list: Collection of tuples (user_id, device_id, algorithm, count) + as passed to claim_e2e_one_time_keys. Returns: - A tuple of key name (algorithm + key ID) and key JSON, if an - OTK was found. + A list of tuples (user_id, device_id, algorithm, key_id, key_json) + for each OTK claimed. """ - - # We can use RETURNING to do the fetch and DELETE in once step. sql = """ - DELETE FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? 
- AND key_id IN ( - SELECT key_id FROM e2e_one_time_keys_json - WHERE user_id = ? AND device_id = ? AND algorithm = ? - LIMIT ? - ) - RETURNING key_id, key_json - """ - - txn.execute( - sql, - (user_id, device_id, algorithm, user_id, device_id, algorithm, count), + WITH claims(user_id, device_id, algorithm, claim_count) AS ( + VALUES ? + ), ranked_keys AS ( + SELECT + user_id, device_id, algorithm, key_id, claim_count, + ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r + FROM e2e_one_time_keys_json + JOIN claims USING (user_id, device_id, algorithm) + ) + DELETE FROM e2e_one_time_keys_json k + WHERE (user_id, device_id, algorithm, key_id) IN ( + SELECT user_id, device_id, algorithm, key_id + FROM ranked_keys + WHERE r <= claim_count + ) + RETURNING user_id, device_id, algorithm, key_id, key_json; + """ + otk_rows = cast( + List[Tuple[str, str, str, str, str]], txn.execute_values(sql, query_list) ) - otk_rows = list(txn) - if not otk_rows: - return [] - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) + seen_user_device: Set[Tuple[str, str]] = set() + for user_id, device_id, _, _, _ in otk_rows: + if (user_id, device_id) in seen_user_device: + continue + seen_user_device.add((user_id, device_id)) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) + ) - return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] + return otk_rows class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): From c8abf552321e6ba8f703f2803518fa2a9914194b Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 28 Oct 2023 00:41:51 +0100 Subject: [PATCH 07/15] Don't bother using the bulk query on SQLite We could probably use executemany? But I don't care. --- synapse/storage/databases/main/end_to_end_keys.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 865640c95ec0..7407ac11289a 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1128,19 +1128,18 @@ async def claim_e2e_one_time_keys( may be less than the input counts. In this case, the returned counts are the number of claims that were not fulfilled. """ - results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} missing: List[Tuple[str, str, str, int]] = [] - if self.database_engine.supports_returning: - # If we support RETURNING clause we can use a single query that - # allows us to use autocommit mode. + if isinstance(self.database_engine, PostgresEngine): + # If we can use execute_values we can use a single batch query + # in autocommit mode. 
unfulfilled_claim_counts: Dict[Tuple[str, str, str], int] = {} for user_id, device_id, algorithm, count in query_list: unfulfilled_claim_counts[user_id, device_id, algorithm] = count bulk_claims = await self.db_pool.runInteraction( "claim_e2e_one_time_keys", - self._claim_e2e_one_time_keys_returning, + self._claim_e2e_one_time_keys_bulk, query_list, db_autocommit=True, ) @@ -1283,7 +1282,7 @@ def _claim_e2e_one_time_key_simple( return [(f"{algorithm}:{key_id}", key_json) for key_id, key_json in otk_rows] @trace - def _claim_e2e_one_time_keys_returning( + def _claim_e2e_one_time_keys_bulk( self, txn: LoggingTransaction, query_list: Iterable[Tuple[str, str, str, int]], From f34624be1cff87a0859ff3cf97d8f86d2f2cec9a Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sat, 28 Oct 2023 00:18:50 +0100 Subject: [PATCH 08/15] Changelog --- changelog.d/16565.feature | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/16565.feature diff --git a/changelog.d/16565.feature b/changelog.d/16565.feature new file mode 100644 index 000000000000..3e01ef73b0cc --- /dev/null +++ b/changelog.d/16565.feature @@ -0,0 +1 @@ +Improve the performance of claiming multiple one-time-keys. From d1c7fff5f06dc91fe8edc34bc669aef5e50280a2 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Sun, 29 Oct 2023 01:55:43 +0100 Subject: [PATCH 09/15] Update changelog.d/16565.feature --- changelog.d/16565.feature | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/changelog.d/16565.feature b/changelog.d/16565.feature index 3e01ef73b0cc..c807945fa816 100644 --- a/changelog.d/16565.feature +++ b/changelog.d/16565.feature @@ -1 +1 @@ -Improve the performance of claiming multiple one-time-keys. +Improve the performance of claiming encryption keys. From 14327cd4e86424b80bb8e676aaf11d573ea863f4 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Oct 2023 13:37:34 +0000 Subject: [PATCH 10/15] Fix docstring formatting Co-authored-by: Patrick Cloke --- synapse/storage/databases/main/end_to_end_keys.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 7407ac11289a..83ee473f7c77 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1291,7 +1291,7 @@ def _claim_e2e_one_time_keys_bulk( Args: query_list: Collection of tuples (user_id, device_id, algorithm, count) - as passed to claim_e2e_one_time_keys. + as passed to claim_e2e_one_time_keys. Returns: A list of tuples (user_id, device_id, algorithm, key_id, key_json) From 24c032ee1d36414ad925f91761462e54fabf35a9 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Oct 2023 19:21:16 +0000 Subject: [PATCH 11/15] New test case --- tests/handlers/test_e2e_keys.py | 157 ++++++++++++++++++++++++++++++++ 1 file changed, 157 insertions(+) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index c5556f284491..305c7b1f9fab 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -174,6 +174,163 @@ def test_claim_one_time_key(self) -> None: }, ) + def test_claim_one_time_key_bulk(self) -> None: + """Like test_claim_one_time_key but claims multiple keys in one handler call.""" + # Apologies to the reader. This test is a little too verbose. It is particularly + # tricky to make assertions neatly with all these nested dictionaries in play. + + # Three users with two devices each. Each device uses two algorithms. 
+ # Each algorithm is invoked with two keys. + alice = f"@alice:{self.hs.hostname}" + brian = f"@brian:{self.hs.hostname}" + chris = f"@chris:{self.hs.hostname}" + one_time_keys = { + alice: { + "alice_dev_1": { + "alg1:k1": {"dummy_id": 1}, + "alg1:k2": {"dummy_id": 2}, + "alg2:k3": {"dummy_id": 3}, + "alg2:k4": {"dummy_id": 4}, + }, + "alice_dev_2": { + "alg1:k5": {"dummy_id": 5}, + "alg1:k6": {"dummy_id": 6}, + "alg2:k7": {"dummy_id": 7}, + "alg2:k8": {"dummy_id": 8}, + }, + }, + brian: { + "brian_dev_1": { + "alg1:k9": {"dummy_id": 9}, + "alg1:k10": {"dummy_id": 10}, + "alg2:k11": {"dummy_id": 11}, + "alg2:k12": {"dummy_id": 12}, + }, + "brian_dev_2": { + "alg1:k13": {"dummy_id": 13}, + "alg1:k14": {"dummy_id": 14}, + "alg2:k15": {"dummy_id": 15}, + "alg2:k16": {"dummy_id": 16}, + }, + }, + chris: { + "chris_dev_1": { + "alg1:k17": {"dummy_id": 17}, + "alg1:k18": {"dummy_id": 18}, + "alg2:k19": {"dummy_id": 19}, + "alg2:k20": {"dummy_id": 20}, + }, + "chris_dev_2": { + "alg1:k21": {"dummy_id": 21}, + "alg1:k22": {"dummy_id": 22}, + "alg2:k23": {"dummy_id": 23}, + "alg2:k24": {"dummy_id": 24}, + }, + }, + } + for user_id, devices in one_time_keys.items(): + for device_id, keys_dict in devices.items(): + counts = self.get_success( + self.handler.upload_keys_for_user( + user_id, + device_id, + {"one_time_keys": keys_dict}, + ) + ) + # The upload should report 2 keys per algorithm. + expected_counts = { + "one_time_key_counts": { + # See count_e2e_one_time_keys for why this is hardcoded. + "signed_curve25519": 0, + "alg1": 2, + "alg2": 2, + }, + } + self.assertEqual(counts, expected_counts) + + # Claim a variety of keys. + # Raw format, easier to make test assertions about. + claims_to_make = { + (alice, "alice_dev_1", "alg1"): 1, + (alice, "alice_dev_1", "alg2"): 2, + (alice, "alice_dev_2", "alg2"): 1, + (brian, "brian_dev_1", "alg1"): 2, + (brian, "brian_dev_2", "alg2"): 9001, + (chris, "chris_dev_2", "alg2"): 1, + } + # Convert to the format the handler wants. + query: Dict[str, Dict[str, Dict[str, int]]] = {} + for (user_id, device_id, algorithm), count in claims_to_make.items(): + query.setdefault(user_id, {}).setdefault(device_id, {})[algorithm] = count + claim_res = self.get_success( + self.handler.claim_one_time_keys( + query, + self.requester, + timeout=None, + always_include_fallback_keys=False, + ) + ) + + # No failures, please! + self.assertEqual(claim_res["failures"], {}) + + # Check that we get exactly the (user, device, algorithm)s we asked for. + got_otks = claim_res["one_time_keys"] + claimed_user_device_algorithms = { + (user_id, device_id, alg_key_id.split(":")[0]) + for user_id, devices in got_otks.items() + for device_id, key_dict in devices.items() + for alg_key_id in key_dict + } + self.assertEqual(claimed_user_device_algorithms, set(claims_to_make)) + + # Now check the keys we got are what we expected. 
+ def assertExactlyOneOtk( + user_id: str, device_id: str, *alg_key_pairs: str + ) -> None: + key_dict = got_otks[user_id][device_id] + found = 0 + for alg_key in alg_key_pairs: + if alg_key in key_dict: + expected_key_json = one_time_keys[user_id][device_id][alg_key] + self.assertEqual(key_dict[alg_key], expected_key_json) + found += 1 + self.assertEqual(found, 1) + + def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None: + key_dict = got_otks[user_id][device_id] + for alg_key in alg_key_pairs: + expected_key_json = one_time_keys[user_id][device_id][alg_key] + self.assertEqual(key_dict[alg_key], expected_key_json) + + assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2") + assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg1:k8") + assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg1:k24") + + assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4") + assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10") + assertAllOtks(brian, "brian_dev_2", "alg2:k15", "alg2:k16") + + # Now check the unused key counts. + for user_id, devices in one_time_keys.items(): + for device_id in devices: + counts_by_alg = self.get_success( + self.store.count_e2e_one_time_keys(user_id, device_id) + ) + # Somewhat fiddley to compute the expected count dict. + expected_counts_by_alg = { + "signed_curve25519": 0, + } + for alg in ["alg1", "alg2"]: + claim_count = claims_to_make.get((user_id, device_id, alg), 0) + remaining_count = max(0, 2 - claim_count) + if remaining_count > 0: + expected_counts_by_alg[alg] = remaining_count + + self.assertEqual( + counts_by_alg, expected_counts_by_alg, f"{user_id}:{device_id}" + ) + def test_fallback_key(self) -> None: local_user = "@boris:" + self.hs.hostname device_id = "xyz" From 34016d8c7b824ca2513ef6a99302dfa7fe2c2dae Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Oct 2023 19:26:21 +0000 Subject: [PATCH 12/15] Drive-by docstring --- synapse/handlers/e2e_keys.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 8c6432035d1c..2cea92c06582 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -739,6 +739,16 @@ async def claim_client_keys(destination: str) -> None: async def upload_keys_for_user( self, user_id: str, device_id: str, keys: JsonDict ) -> JsonDict: + """ + Args: + user_id: user whose keys are being uploaded. + device_id: device whose keys are being uploaded. + keys: the body of a /keys/upload request. + + Returns a dictionary with one field: + "one_time_keys": A mapping from algorithm to number of keys for that + algorithm, including those previously persisted. + """ # This can only be called from the main process. assert isinstance(self.device_handler, DeviceHandler) From 19e1427d121ace660d91f80dad051060ca58551f Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Oct 2023 19:28:59 +0000 Subject: [PATCH 13/15] Define `missing` using a comprehension --- synapse/storage/databases/main/end_to_end_keys.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 83ee473f7c77..72012f3b78eb 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1152,13 +1152,10 @@ async def claim_e2e_one_time_keys( unfulfilled_claim_counts[(user_id, device_id, algorithm)] -= 1 # Did we get enough OTKs? 
- for ( - user_id, - device_id, - algorithm, - ), count in unfulfilled_claim_counts.items(): - if count > 0: - missing.append((user_id, device_id, algorithm, count)) + missing = [ + (user, device, alg, count) + for (user, device, alg), count in unfulfilled_claim_counts.items() + ] else: for user_id, device_id, algorithm, count in query_list: claim_rows = await self.db_pool.runInteraction( From 6dd13b72e9a9d299c68e28b61154cc971d7db180 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Oct 2023 20:01:14 +0000 Subject: [PATCH 14/15] Add back in the missing comprehension condition --- synapse/storage/databases/main/end_to_end_keys.py | 1 + 1 file changed, 1 insertion(+) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 72012f3b78eb..bbfad5a4a086 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1155,6 +1155,7 @@ async def claim_e2e_one_time_keys( missing = [ (user, device, alg, count) for (user, device, alg), count in unfulfilled_claim_counts.items() + if count > 0 ] else: for user_id, device_id, algorithm, count in query_list: From 71c91e5422864bcfe5e4f1b1c9b0bc76405026b5 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Mon, 30 Oct 2023 20:36:41 +0000 Subject: [PATCH 15/15] Fix test expectations Co-authored-by: Patrick Cloke --- tests/handlers/test_e2e_keys.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 305c7b1f9fab..28d3011281c8 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -303,9 +303,10 @@ def assertAllOtks(user_id: str, device_id: str, *alg_key_pairs: str) -> None: expected_key_json = one_time_keys[user_id][device_id][alg_key] self.assertEqual(key_dict[alg_key], expected_key_json) + # Expect a single arbitrary key to be returned. assertExactlyOneOtk(alice, "alice_dev_1", "alg1:k1", "alg1:k2") - assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg1:k8") - assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg1:k24") + assertExactlyOneOtk(alice, "alice_dev_2", "alg2:k7", "alg2:k8") + assertExactlyOneOtk(chris, "chris_dev_2", "alg2:k23", "alg2:k24") assertAllOtks(alice, "alice_dev_1", "alg2:k3", "alg2:k4") assertAllOtks(brian, "brian_dev_1", "alg1:k9", "alg1:k10")
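
The final shape of the change (patches 06 and 07) replaces the per-device loop with a single Postgres statement: a claims CTE built from the requested (user, device, algorithm, count) tuples, a ROW_NUMBER() OVER (PARTITION BY ...) ranking that caps each group at its requested count, and a DELETE ... RETURNING that removes and returns the claimed keys in one round trip, which is what allows the batch to run in autocommit mode. The sketch below is a minimal way to exercise the same query outside Synapse, assuming a direct psycopg2 connection: the SQL and the e2e_one_time_keys_json table come from the diff above, the single placeholder is written as psycopg2's native %s (the patch's ? relies on Synapse's own parameter-style handling), and the helper name, connection details and example claims are hypothetical.

# Minimal standalone sketch of the bulk OTK claim from patches 06/07, assuming a
# direct psycopg2 connection to a database containing e2e_one_time_keys_json.
# Everything outside the SQL (function name, DSN, example claims) is illustrative.
from typing import List, Tuple

import psycopg2
from psycopg2.extras import execute_values

BULK_CLAIM_SQL = """
    WITH claims(user_id, device_id, algorithm, claim_count) AS (
        VALUES %s
    ), ranked_keys AS (
        SELECT
            user_id, device_id, algorithm, key_id, claim_count,
            ROW_NUMBER() OVER (PARTITION BY (user_id, device_id, algorithm)) AS r
        FROM e2e_one_time_keys_json
        JOIN claims USING (user_id, device_id, algorithm)
    )
    DELETE FROM e2e_one_time_keys_json k
    WHERE (user_id, device_id, algorithm, key_id) IN (
        SELECT user_id, device_id, algorithm, key_id
        FROM ranked_keys
        WHERE r <= claim_count
    )
    RETURNING user_id, device_id, algorithm, key_id, key_json
"""


def bulk_claim_otks(
    conn: "psycopg2.extensions.connection",
    claims: List[Tuple[str, str, str, int]],
) -> List[Tuple[str, str, str, str, str]]:
    """Delete and return up to claim_count keys per (user, device, algorithm).

    execute_values expands the single %s placeholder into the VALUES list of
    the claims CTE, so the whole batch is sent as one statement.
    """
    with conn.cursor() as cur:
        rows = execute_values(cur, BULK_CLAIM_SQL, claims, fetch=True)
    conn.commit()
    return rows


# Hypothetical usage: claim one alg1 key and two alg2 keys for one device.
# conn = psycopg2.connect("dbname=synapse")
# rows = bulk_claim_otks(
#     conn,
#     [("@alice:example.org", "DEV1", "alg1", 1),
#      ("@alice:example.org", "DEV1", "alg2", 2)],
# )

Note that the window function has no ORDER BY, so which keys within a group get claimed is arbitrary; the test added in patch 11 allows for this by asserting that exactly one of the candidate keys was returned rather than a specific one.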