Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Do not allow a thread to start for any event with a relation.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Nov 11, 2021
1 parent 0c15d9c commit 8e63d56
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 38 deletions.
12 changes: 5 additions & 7 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,10 +1031,7 @@ async def _validate_event_relation(self, event: EventBase) -> None:
if not parent_event:
# There must be some reason that the client knows the event exists,
# see if there are existing relations. If so, assume everything is fine.
pagination = await self.store.get_relations_for_event(
relates_to, limit=1
)
if pagination.chunk:
if await self.store.event_has_relations(relates_to):
return

# Otherwise, the client can't know about the parent event!
Expand All @@ -1059,9 +1056,10 @@ async def _validate_event_relation(self, event: EventBase) -> None:
# If this relation is a thread, then ensure thread head is not part of
# a thread already.
elif relation_type == RelationTypes.THREAD:
already_thread = await self.store.get_event_thread(relates_to)
if already_thread:
raise SynapseError(400, "Can't fork threads")
if await self.store.event_has_relations(relates_to):
raise SynapseError(
400, "Cannot start threads from an event with a relation"
)

@measure_func("handle_new_client_event")
async def handle_new_client_event(
Expand Down
50 changes: 19 additions & 31 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,25 @@ def _get_recent_references_for_event_txn(
"get_recent_references_for_event", _get_recent_references_for_event_txn
)

async def event_has_relations(self, parent_id: str) -> bool:
"""Check if a given event has any known relations in the database.
Args:
parent_id: The event to check.
Returns:
True if the event has any relations.
"""

result = await self.db_pool.simple_select_one_onecol(
table="event_relations",
keyvalues={"relates_to_id": parent_id},
retcol="event_id",
allow_none=True,
desc="event_has_relations",
)
return result is not None

@cached(tree=True)
async def get_aggregation_groups_for_event(
self,
Expand Down Expand Up @@ -436,37 +455,6 @@ def _get_if_user_has_annotated_event(txn: LoggingTransaction) -> bool:
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)

async def get_event_thread(self, event_id: str) -> Optional[str]:
"""Return an event's thread.
Args:
event_id: The event being used as the start of a new thread.
Returns:
The thread ID of the event.
"""

sql = """
SELECT relates_to_id FROM event_relations
WHERE
event_id = ?
AND relation_type = ?
LIMIT 1;
"""

def _get_thread_id(txn) -> Optional[str]:
txn.execute(
sql,
(
event_id,
RelationTypes.THREAD,
),
)

return txn.fetchone()

return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)


class RelationsStore(RelationsWorkerStore):
pass

0 comments on commit 8e63d56

Please sign in to comment.