Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Ensure forked threads are not allowed.
Browse files Browse the repository at this point in the history
  • Loading branch information
clokep committed Oct 22, 2021
1 parent 1fff047 commit a750347
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 0 deletions.
7 changes: 7 additions & 0 deletions synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,6 +1048,13 @@ async def _validate_event_relation(self, event: EventBase) -> None:
if already_exists:
raise SynapseError(400, "Can't send same reaction twice")

# If this relation is a thread, then ensure thread head is not part of
# a thread already.
elif relation_type == RelationTypes.THREAD:
already_thread = await self.store.get_event_thread(relates_to)
if already_thread:
raise SynapseError(400, "Can't fork threads")

@measure_func("handle_new_client_event")
async def handle_new_client_event(
self,
Expand Down
31 changes: 31 additions & 0 deletions synapse/storage/databases/main/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,37 @@ def _get_if_user_has_annotated_event(txn):
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
)

async def get_event_thread(self, event_id: str) -> Optional[str]:
"""Return an event's thread.
Args:
event_id: The event being used as the start of a new thread.
Returns:
The thread ID of the event.
"""

sql = """
SELECT relates_to_id FROM event_relations
WHERE
event_id = ?
AND relation_type = ?
LIMIT 1;
"""

def _get_thread_id(txn) -> Optional[str]:
txn.execute(
sql,
(
event_id,
RelationTypes.THREAD,
),
)

return txn.fetchone()

return await self.db_pool.runInteraction("get_thread_id", _get_thread_id)


class RelationsStore(RelationsWorkerStore):
pass
19 changes: 19 additions & 0 deletions tests/rest/client/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,25 @@ def test_deny_double_react(self):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
self.assertEquals(400, channel.code, channel.json_body)

def test_deny_forked_thread(self):
"""It is invalid to start a thread off a thread."""
channel = self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=self.parent_id,
)
self.assertEquals(200, channel.code, channel.json_body)
parent_id = channel.json_body["event_id"]

channel = self._send_relation(
RelationTypes.THREAD,
"m.room.message",
content={"msgtype": "m.text", "body": "foo"},
parent_id=parent_id,
)
self.assertEquals(400, channel.code, channel.json_body)

def test_basic_paginate_relations(self):
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
Expand Down

0 comments on commit a750347

Please sign in to comment.