From 600a665c806c9d20db6d57094c15f9169b03f45a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 11:20:35 -0500 Subject: [PATCH 1/8] Pass the room ID to get_auth_chain_ids. --- synapse/federation/federation_server.py | 6 ++++-- synapse/handlers/federation.py | 6 +++--- synapse/storage/databases/main/event_federation.py | 9 ++++++--- 3 files changed, 13 insertions(+), 8 deletions(-) diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index ffc735ba254c..06c5e7a9e0f3 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -447,7 +447,7 @@ async def on_state_ids_request( async def _on_state_ids_request_compute(self, room_id, event_id): state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) + auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids) return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} async def _on_context_state_request_compute( @@ -460,7 +460,9 @@ async def _on_context_state_request_compute( else: pdus = (await self.state.get_current_state(room_id)).values() - auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + auth_chain = await self.store.get_auth_chain( + room_id, [pdu.event_id for pdu in pdus] + ) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 2ead626a4d5a..3fe02b719595 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1317,7 +1317,7 @@ async def send_invite(self, target_host, event): async def on_event_auth(self, event_id: str) -> List[EventBase]: event = await self.store.get_event(event_id) auth = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + event.room_id, list(event.auth_event_ids()), include_given=True ) return list(auth) @@ -1580,7 +1580,7 @@ async def on_send_join_request(self, origin, pdu): prev_state_ids = await context.get_prev_state_ids() state_ids = list(prev_state_ids.values()) - auth_chain = await self.store.get_auth_chain(state_ids) + auth_chain = await self.store.get_auth_chain(event.room_id, state_ids) state = await self.store.get_events(list(prev_state_ids.values())) @@ -2219,7 +2219,7 @@ async def on_query_auth( # Now get the current auth_chain for the event. local_auth_chain = await self.store.get_auth_chain( - list(event.auth_event_ids()), include_given=True + room_id, list(event.auth_event_ids()), include_given=True ) # TODO: Check if we would now reject event_id. If so we need to tell diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 18ddb92fcca5..24b08e2f7e13 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -54,11 +54,12 @@ def __init__(self, database: DatabasePool, db_conn, hs): ) # type: LruCache[str, List[Tuple[str, int]]] async def get_auth_chain( - self, event_ids: Collection[str], include_given: bool = False + self, room_id: str, event_ids: Collection[str], include_given: bool = False ) -> List[EventBase]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result @@ -66,23 +67,25 @@ async def get_auth_chain( list of events """ event_ids = await self.get_auth_chain_ids( - event_ids, include_given=include_given + room_id, event_ids, include_given=include_given ) return await self.get_events_as_list(event_ids) async def get_auth_chain_ids( self, + room_id: str, event_ids: Collection[str], include_given: bool = False, ) -> List[str]: """Get auth events for given event_ids. The events *must* be state events. Args: + room_id: The room the event is in. event_ids: state events include_given: include the given events in result Returns: - An awaitable which resolve to a list of event_ids + list of event_ids """ return await self.db_pool.runInteraction( "get_auth_chain_ids", From cd218d07f4034644ecaec48944d5e1afe49e3780 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 11:50:40 -0500 Subject: [PATCH 2/8] Add tests for get_auth_chain_ids. --- tests/storage/test_event_federation.py | 75 ++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 06000f81a63d..66730c060265 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -118,8 +118,7 @@ def insert_event(txn, i, room_id): r = self.get_success(self.store.get_rooms_with_many_extremities(5, 1, [room1])) self.assertTrue(r == [room2] or r == [room3]) - @parameterized.expand([(True,), (False,)]) - def test_auth_difference(self, use_chain_cover_index: bool): + def _setup_auth_chain(self, use_chain_cover_index: bool) -> str: room_id = "@ROOM:local" # The silly auth graph we use to test the auth difference algorithm, @@ -165,7 +164,7 @@ def test_auth_difference(self, use_chain_cover_index: bool): "j": 1, } - # Mark the room as not having a cover index + # Mark the room as maybe having a cover index. def store_room(txn): self.store.db_pool.simple_insert_txn( @@ -222,6 +221,76 @@ def insert_event(txn): ) ) + return room_id + + def test_auth_chain_ids(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(False) + + # a and b have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["b"])) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["a", "b"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["c"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + # d and e have the same auth chain. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["d"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["e"])) + self.assertCountEqual(auth_chain_ids, ["f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["f"])) + self.assertCountEqual(auth_chain_ids, ["g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["g"])) + self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"])) + self.assertEqual(auth_chain_ids, ["k"]) + + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"])) + self.assertEqual(auth_chain_ids, ["j"]) + + # j and k have no parents. + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"])) + self.assertEqual(auth_chain_ids, []) + auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"])) + self.assertEqual(auth_chain_ids, []) + + # More complex input sequences. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "c", "d"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["h", "i"]) + ) + self.assertCountEqual(auth_chain_ids, ["k", "j"]) + + # e gets returned even though include_given is false, but it is in the + # auth chain of b. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["b", "e"]) + ) + self.assertCountEqual(auth_chain_ids, ["e", "f", "g", "h", "i", "j", "k"]) + + # Test include_given. + auth_chain_ids = self.get_success( + self.store.get_auth_chain_ids(room_id, ["i"], include_given=True) + ) + self.assertCountEqual(auth_chain_ids, ["i", "j"]) + + @parameterized.expand([(True,), (False,)]) + def test_auth_difference(self, use_chain_cover_index: bool): + room_id = self._setup_auth_chain(use_chain_cover_index) + # Now actually test that various combinations give the right result: difference = self.get_success( From 7901b6f2f3b32acc0ccc95779f9fe3af3a83c47b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 13:09:21 -0500 Subject: [PATCH 3/8] Use the chain cover to calculate auth events. --- .../databases/main/event_federation.py | 127 ++++++++++++++++++ tests/storage/test_event_federation.py | 3 +- 2 files changed, 129 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 24b08e2f7e13..1bbce000223b 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -87,6 +87,24 @@ async def get_auth_chain_ids( Returns: list of event_ids """ + + # Check if we have indexed the room so we can use the chain cover + # algorithm. + room = await self.get_room(room_id) + if room["has_auth_chain_index"]: + try: + return await self.db_pool.runInteraction( + "get_auth_chain_ids_chains", + self._get_auth_chain_ids_using_cover_index_txn, + room_id, + event_ids, + include_given, + ) + except _NoChainCoverIndex: + # For whatever reason we don't actually have a chain cover index + # for the events in question, so we fall back to the old method. + pass + return await self.db_pool.runInteraction( "get_auth_chain_ids", self._get_auth_chain_ids_txn, @@ -94,9 +112,118 @@ async def get_auth_chain_ids( include_given, ) + def _get_auth_chain_ids_using_cover_index_txn( + self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool + ) -> List[str]: + """Calculates the auth chain IDs using the chain index.""" + + # First we look up the chain ID/sequence numbers for all the events, and + # work out the chain/sequence numbers reachable from each state set. + + initial_events = set(event_ids) + + # All the events that we've found that are reachable from the events. + seen_events = set() # type: Set[str] + + # A map from chain ID to max sequence number reachable from any event ID. + chains = {} # type: Dict[int, int] + + sql = """ + SELECT event_id, chain_id, sequence_number + FROM event_auth_chains + WHERE %s + """ + for batch in batch_iter(initial_events, 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "event_id", batch + ) + txn.execute(sql % (clause,), args) + + for event_id, chain_id, sequence_number in txn: + seen_events.add(event_id) + chains[chain_id] = max(sequence_number, chains.get(chain_id, 0)) + + # Check that we actually have a chain ID for all the events. + events_missing_chain_info = initial_events.difference(seen_events) + if events_missing_chain_info: + # This can happen due to e.g. downgrade/upgrade of the server. We + # raise an exception and fall back to the previous algorithm. + logger.info( + "Unexpectedly found that events don't have chain IDs in room %s: %s", + room_id, + events_missing_chain_info, + ) + raise _NoChainCoverIndex(room_id) + + # Now we look up all links for the chains we have, adding chains that + # are reachable from each set. + sql = """ + SELECT + origin_chain_id, origin_sequence_number, + target_chain_id, target_sequence_number + FROM event_auth_chain_links + WHERE %s + """ + + # (We need to take a copy of `chains` as we want to mutate it in the loop) + for batch in batch_iter(set(chains), 1000): + clause, args = make_in_list_sql_clause( + txn.database_engine, "origin_chain_id", batch + ) + txn.execute(sql % (clause,), args) + + for ( + origin_chain_id, + origin_sequence_number, + target_chain_id, + target_sequence_number, + ) in txn: + if origin_sequence_number <= chains.get(origin_chain_id, 0): + chains[target_chain_id] = max( + target_sequence_number, + chains.get(target_chain_id, 0), + ) + + # Now for each chain we figure out the maximum sequence number reachable + # from *any* event ID. Events with a sequence less than that are in the + # auth chain. + if include_given: + results = initial_events + else: + results = set() + + if isinstance(self.database_engine, PostgresEngine): + # We can use `execute_values` to efficiently fetch the gaps when + # using postgres. + sql = """ + SELECT event_id + FROM event_auth_chains AS c, (VALUES ?) AS l(chain_id, max_seq) + WHERE + c.chain_id = l.chain_id + AND sequence_number <= max_seq + """ + + rows = txn.execute_values(sql, chains.items()) + results.update(r for r, in rows) + else: + # For SQLite we just fall back to doing a noddy for loop. + sql = """ + SELECT event_id FROM event_auth_chains + WHERE chain_id = ? AND sequence_number <= ? + """ + for chain_id, max_no in chains.items(): + txn.execute(sql, (chain_id, max_no)) + results.update(r for r, in txn) + + return list(results) + def _get_auth_chain_ids_txn( self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool ) -> List[str]: + """Calculates the auth chain IDs. + + This is used when we don't have a cover index for the room. + """ if include_given: results = set(event_ids) else: diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 66730c060265..d597d712d675 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -223,8 +223,9 @@ def insert_event(txn): return room_id + @parameterized.expand([(True,), (False,)]) def test_auth_chain_ids(self, use_chain_cover_index: bool): - room_id = self._setup_auth_chain(False) + room_id = self._setup_auth_chain(use_chain_cover_index) # a and b have the same auth chain. auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["a"])) From af74e4b36981f0bc5e516d6352f590f106e683ac Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 13:43:21 -0500 Subject: [PATCH 4/8] Only include the given events when requested. --- .../databases/main/event_federation.py | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 1bbce000223b..f545f5ca2c74 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -125,8 +125,8 @@ def _get_auth_chain_ids_using_cover_index_txn( # All the events that we've found that are reachable from the events. seen_events = set() # type: Set[str] - # A map from chain ID to max sequence number reachable from any event ID. - chains = {} # type: Dict[int, int] + # A map from chain ID to max sequence number of the given events. + event_chains = {} # type: Dict[int, int] sql = """ SELECT event_id, chain_id, sequence_number @@ -141,7 +141,9 @@ def _get_auth_chain_ids_using_cover_index_txn( for event_id, chain_id, sequence_number in txn: seen_events.add(event_id) - chains[chain_id] = max(sequence_number, chains.get(chain_id, 0)) + event_chains[chain_id] = max( + sequence_number, event_chains.get(chain_id, 0) + ) # Check that we actually have a chain ID for all the events. events_missing_chain_info = initial_events.difference(seen_events) @@ -165,8 +167,12 @@ def _get_auth_chain_ids_using_cover_index_txn( WHERE %s """ - # (We need to take a copy of `chains` as we want to mutate it in the loop) - for batch in batch_iter(set(chains), 1000): + # A map from chain ID to max sequence number *reachable* from any event ID. + chains = dict(event_chains) + + # We need to take a copy of `event_chains` as we need to separate chain + # IDs / seq nos from the given events vs. reachable ones. + for batch in batch_iter(event_chains, 1000): clause, args = make_in_list_sql_clause( txn.database_engine, "origin_chain_id", batch ) @@ -184,6 +190,20 @@ def _get_auth_chain_ids_using_cover_index_txn( chains.get(target_chain_id, 0), ) + # The chain ID / seq no of a given event is reachable from + # a different event, discard that chain from the given events. + if ( + target_chain_id in event_chains + and target_sequence_number >= event_chains[target_chain_id] + ): + event_chains.pop(target_chain_id) + + # Don't include the given events (since we're finding the auth chain of + # those events). + for chain_id in event_chains: + if event_chains[chain_id] == chains[chain_id]: + chains[chain_id] -= 1 + # Now for each chain we figure out the maximum sequence number reachable # from *any* event ID. Events with a sequence less than that are in the # auth chain. From 472f80be9ff3dd27861e44fa91765858430c1180 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 14:12:09 -0500 Subject: [PATCH 5/8] Newsfragment --- changelog.d/9576.misc | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/9576.misc diff --git a/changelog.d/9576.misc b/changelog.d/9576.misc new file mode 100644 index 000000000000..bc257d05b744 --- /dev/null +++ b/changelog.d/9576.misc @@ -0,0 +1 @@ +Improve efficiency of calculating the auth chain in large rooms. From 6dc3f8a3dbfc86d91412981c0e601372f4434798 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 9 Mar 2021 14:17:08 -0500 Subject: [PATCH 6/8] Clarify comments. --- synapse/storage/databases/main/event_federation.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index f545f5ca2c74..1ae88c9f56ac 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -117,8 +117,7 @@ def _get_auth_chain_ids_using_cover_index_txn( ) -> List[str]: """Calculates the auth chain IDs using the chain index.""" - # First we look up the chain ID/sequence numbers for all the events, and - # work out the chain/sequence numbers reachable from each state set. + # First we look up the chain ID/sequence numbers for the given events. initial_events = set(event_ids) @@ -158,7 +157,7 @@ def _get_auth_chain_ids_using_cover_index_txn( raise _NoChainCoverIndex(room_id) # Now we look up all links for the chains we have, adding chains that - # are reachable from each set. + # are reachable from any event. sql = """ SELECT origin_chain_id, origin_sequence_number, From 7dd9d3180403600b2b23af4e5f3a9b7f2857a4f4 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Mar 2021 07:45:23 -0500 Subject: [PATCH 7/8] Simplify handling of include_given. --- .../databases/main/event_federation.py | 24 ++++++------------- 1 file changed, 7 insertions(+), 17 deletions(-) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 1ae88c9f56ac..4a6938379fb9 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -167,10 +167,9 @@ def _get_auth_chain_ids_using_cover_index_txn( """ # A map from chain ID to max sequence number *reachable* from any event ID. - chains = dict(event_chains) + chains = {} # type: Dict[int, int] - # We need to take a copy of `event_chains` as we need to separate chain - # IDs / seq nos from the given events vs. reachable ones. + # Add all linked chains reachable from initial set of chains. for batch in batch_iter(event_chains, 1000): clause, args = make_in_list_sql_clause( txn.database_engine, "origin_chain_id", batch @@ -183,25 +182,16 @@ def _get_auth_chain_ids_using_cover_index_txn( target_chain_id, target_sequence_number, ) in txn: - if origin_sequence_number <= chains.get(origin_chain_id, 0): + if origin_sequence_number <= event_chains.get(origin_chain_id, 0): chains[target_chain_id] = max( target_sequence_number, chains.get(target_chain_id, 0), ) - # The chain ID / seq no of a given event is reachable from - # a different event, discard that chain from the given events. - if ( - target_chain_id in event_chains - and target_sequence_number >= event_chains[target_chain_id] - ): - event_chains.pop(target_chain_id) - - # Don't include the given events (since we're finding the auth chain of - # those events). - for chain_id in event_chains: - if event_chains[chain_id] == chains[chain_id]: - chains[chain_id] -= 1 + # Add the initial set of chains, excluding the sequence corresponding to + # initial event. + for chain_id, seq_no in event_chains.items(): + chains[chain_id] = max(seq_no - 1, chains.get(chain_id, 0)) # Now for each chain we figure out the maximum sequence number reachable # from *any* event ID. Events with a sequence less than that are in the From da5a0ff0e83f3a02936c84b096d7c6399d509c8b Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 10 Mar 2021 07:59:58 -0500 Subject: [PATCH 8/8] Add back a comment. --- synapse/storage/databases/main/event_federation.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 4a6938379fb9..332193ad1c93 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -182,6 +182,9 @@ def _get_auth_chain_ids_using_cover_index_txn( target_chain_id, target_sequence_number, ) in txn: + # chains are only reachable if the origin sequence number of + # the link is less than the max sequence number in the + # origin chain. if origin_sequence_number <= event_chains.get(origin_chain_id, 0): chains[target_chain_id] = max( target_sequence_number,