Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Bundle aggregations outside of the serialization method #11612

Merged
merged 4 commits into from
Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/11612.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid database access in the JSON serialization process.
126 changes: 37 additions & 89 deletions synapse/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,18 @@
# limitations under the License.
import collections.abc
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Union,
)
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union

from frozendict import frozendict

from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
from synapse.types import JsonDict
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.frozenutils import unfreeze

from . import EventBase

if TYPE_CHECKING:
from synapse.server import HomeServer

# Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
# (?<!stuff) matches if the current position in the string is not preceded
# by a match for 'stuff'.
Expand Down Expand Up @@ -385,17 +371,12 @@ class EventClientSerializer:
clients.
"""

def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self._msc1849_enabled = hs.config.experimental.msc1849_enabled
self._msc3440_enabled = hs.config.experimental.msc3440_enabled

async def serialize_event(
def serialize_event(
self,
event: Union[JsonDict, EventBase],
time_now: int,
*,
bundle_aggregations: bool = False,
bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
**kwargs: Any,
) -> JsonDict:
"""Serializes a single event.
Expand All @@ -418,66 +399,41 @@ async def serialize_event(
serialized_event = serialize_event(event, time_now, **kwargs)

# Check if there are any bundled aggregations to include with the event.
#
# Do not bundle aggregations if any of the following are true:
#
# * Support is disabled via the configuration or the caller.
# * The event is a state event.
# * The event has been redacted.
if (
self._msc1849_enabled
and bundle_aggregations
and not event.is_state()
and not event.internal_metadata.is_redacted()
Comment on lines -430 to -431
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This bit moved to the top of the helper function in the previous commit.

):
await self._injected_bundled_aggregations(event, time_now, serialized_event)
if bundle_aggregations:
event_aggregations = bundle_aggregations.get(event.event_id)
if event_aggregations:
self._injected_bundled_aggregations(
event,
time_now,
bundle_aggregations[event.event_id],
serialized_event,
)

return serialized_event

async def _injected_bundled_aggregations(
self, event: EventBase, time_now: int, serialized_event: JsonDict
def _injected_bundled_aggregations(
self,
event: EventBase,
time_now: int,
aggregations: JsonDict,
serialized_event: JsonDict,
) -> None:
"""Potentially injects bundled aggregations into the unsigned portion of the serialized event.

Args:
event: The event being serialized.
time_now: The current time in milliseconds
aggregations: The bundled aggregation to serialize.
serialized_event: The serialized event which may be modified.

"""
# Do not bundle aggregations for an event which represents an edit or an
# annotation. It does not make sense for them to have related events.
relates_to = event.content.get("m.relates_to")
if isinstance(relates_to, (dict, frozendict)):
relation_type = relates_to.get("rel_type")
if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
return

event_id = event.event_id
room_id = event.room_id

# The bundled aggregations to include.
aggregations = {}

annotations = await self.store.get_aggregation_groups_for_event(
event_id, room_id
)
if annotations.chunk:
aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
# Make a copy in case the object is cached.
aggregations = aggregations.copy()
Comment on lines +430 to +431
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this? Ahh yes: there's a mutation on +469.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wondered if it would make more sense to make the input to this an attrs class so that the types are much clearer, which would mean we would build the dictionary here instead. Would that be clearer? (I think I'd rather do that as a follow-up since it is a decent amount of changes, but could do it here if you'd like!)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that'd be nice, but not crucial. Sounds like it'd be best as a separate change to me.


references = await self.store.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations[RelationTypes.REFERENCE] = references.to_dict()

edit = None
if event.type == EventTypes.Message:
edit = await self.store.get_applicable_edit(event_id, room_id)

if edit:
if RelationTypes.REPLACE in aggregations:
# If there is an edit, replace the content, preserving existing
# relations.
edit = aggregations[RelationTypes.REPLACE]

# Ensure we take copies of the edit content, otherwise we risk modifying
# the original event.
Expand All @@ -502,27 +458,19 @@ async def _injected_bundled_aggregations(
}

# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
(
thread_count,
latest_thread_event,
) = await self.store.get_thread_summary(event_id, room_id)
if latest_thread_event:
aggregations[RelationTypes.THREAD] = {
# Don't bundle aggregations as this could recurse forever.
"latest_event": await self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=False
),
"count": thread_count,
}

# If any bundled aggregations were found, include them.
if aggregations:
serialized_event["unsigned"].setdefault("m.relations", {}).update(
aggregations
if RelationTypes.THREAD in aggregations:
# Serialize the latest thread event.
latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]

# Don't bundle aggregations as this could recurse forever.
aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
latest_thread_event, time_now, bundle_aggregations=None
)

async def serialize_events(
# Include the bundled aggregations in the event.
serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)

def serialize_events(
self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
) -> List[JsonDict]:
"""Serializes multiple events.
Expand All @@ -535,9 +483,9 @@ async def serialize_events(
Returns:
The list of serialized events
"""
return await yieldable_gather_results(
self.serialize_event, events, time_now=time_now, **kwargs
)
return [
self.serialize_event(event, time_now=time_now, **kwargs) for event in events
]


def copy_power_levels_contents(
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def get_stream(

events.extend(to_add)

chunks = await self._event_serializer.serialize_events(
chunks = self._event_serializer.serialize_events(
events,
time_now,
as_client_event=as_client_event,
Expand Down
16 changes: 7 additions & 9 deletions synapse/handlers/initial_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ async def handle_room(event: RoomsForUser) -> None:
d["inviter"] = event.sender

invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
d["invite"] = self._event_serializer.serialize_event(
invite_event,
time_now,
as_client_event=as_client_event,
Expand Down Expand Up @@ -222,7 +222,7 @@ async def handle_room(event: RoomsForUser) -> None:

d["messages"] = {
"chunk": (
await self._event_serializer.serialize_events(
self._event_serializer.serialize_events(
messages,
time_now=time_now,
as_client_event=as_client_event,
Expand All @@ -232,7 +232,7 @@ async def handle_room(event: RoomsForUser) -> None:
"end": await end_token.to_string(self.store),
}

d["state"] = await self._event_serializer.serialize_events(
d["state"] = self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
Expand Down Expand Up @@ -376,16 +376,14 @@ async def _room_initial_sync_parted(
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(messages, time_now)
self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
},
"state": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(
room_state.values(), time_now
)
self._event_serializer.serialize_events(room_state.values(), time_now)
),
"presence": [],
"receipts": [],
Expand All @@ -404,7 +402,7 @@ async def _room_initial_sync_joined(
# TODO: Do these concurrently
time_now = self.clock.time_msec()
# Don't bundle aggregations as this is a deprecated API.
state = await self._event_serializer.serialize_events(
state = self._event_serializer.serialize_events(
current_state.values(), time_now
)

Expand Down Expand Up @@ -480,7 +478,7 @@ async def get_receipts() -> List[JsonDict]:
"messages": {
"chunk": (
# Don't bundle aggregations as this is a deprecated API.
await self._event_serializer.serialize_events(messages, time_now)
self._event_serializer.serialize_events(messages, time_now)
),
"start": await start_token.to_string(self.store),
"end": await end_token.to_string(self.store),
Expand Down
2 changes: 1 addition & 1 deletion synapse/handlers/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def get_state_events(
room_state = room_state_events[membership_event_id]

now = self.clock.time_msec()
events = await self._event_serializer.serialize_events(room_state.values(), now)
events = self._event_serializer.serialize_events(room_state.values(), now)
return events

async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
Expand Down
8 changes: 5 additions & 3 deletions synapse/handlers/pagination.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,14 +537,16 @@ async def get_messages(
state_dict = await self.store.get_events(list(state_ids.values()))
state = state_dict.values()

aggregations = await self.store.get_bundled_aggregations(events)

time_now = self.clock.time_msec()

chunk = {
"chunk": (
await self._event_serializer.serialize_events(
self._event_serializer.serialize_events(
events,
time_now,
bundle_aggregations=True,
bundle_aggregations=aggregations,
as_client_event=as_client_event,
)
),
Expand All @@ -553,7 +555,7 @@ async def get_messages(
}

if state:
chunk["state"] = await self._event_serializer.serialize_events(
chunk["state"] = self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)

Expand Down
10 changes: 10 additions & 0 deletions synapse/handlers/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,6 +1177,16 @@ async def filter_evts(events: List[EventBase]) -> List[EventBase]:
# `filtered` rather than the event we retrieved from the datastore.
results["event"] = filtered[0]

# Fetch the aggregations.
aggregations = await self.store.get_bundled_aggregations([results["event"]])
aggregations.update(
await self.store.get_bundled_aggregations(results["events_before"])
)
aggregations.update(
await self.store.get_bundled_aggregations(results["events_after"])
)
results["aggregations"] = aggregations

if results["events_after"]:
last_event_id = results["events_after"][-1].event_id
else:
Expand Down
10 changes: 4 additions & 6 deletions synapse/handlers/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,10 +420,10 @@ async def search(
time_now = self.clock.time_msec()

for context in contexts.values():
context["events_before"] = await self._event_serializer.serialize_events(
context["events_before"] = self._event_serializer.serialize_events(
context["events_before"], time_now
)
context["events_after"] = await self._event_serializer.serialize_events(
context["events_after"] = self._event_serializer.serialize_events(
context["events_after"], time_now
)

Expand All @@ -441,9 +441,7 @@ async def search(
results.append(
{
"rank": rank_map[e.event_id],
"result": (
await self._event_serializer.serialize_event(e, time_now)
),
"result": self._event_serializer.serialize_event(e, time_now),
"context": contexts.get(e.event_id, {}),
}
)
Expand All @@ -457,7 +455,7 @@ async def search(
if state_results:
s = {}
for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
s[room_id] = self._event_serializer.serialize_events(
state_events, time_now
)

Expand Down
16 changes: 8 additions & 8 deletions synapse/rest/admin/rooms.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,7 @@ async def on_GET(
event_ids = await self.store.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = await self._event_serializer.serialize_events(events.values(), now)
room_state = self._event_serializer.serialize_events(events.values(), now)
ret = {"state": room_state}

return HTTPStatus.OK, ret
Expand Down Expand Up @@ -744,22 +744,22 @@ async def on_GET(
)

time_now = self.clock.time_msec()
results["events_before"] = await self._event_serializer.serialize_events(
results["events_before"] = self._event_serializer.serialize_events(
results["events_before"],
time_now,
bundle_aggregations=True,
bundle_aggregations=results["aggregations"],
)
results["event"] = await self._event_serializer.serialize_event(
results["event"] = self._event_serializer.serialize_event(
results["event"],
time_now,
bundle_aggregations=True,
bundle_aggregations=results["aggregations"],
)
results["events_after"] = await self._event_serializer.serialize_events(
results["events_after"] = self._event_serializer.serialize_events(
results["events_after"],
time_now,
bundle_aggregations=True,
bundle_aggregations=results["aggregations"],
)
results["state"] = await self._event_serializer.serialize_events(
results["state"] = self._event_serializer.serialize_events(
results["state"], time_now
)

Expand Down
Loading