Skip to content
This repository was archived by the owner on Apr 26, 2024. It is now read-only.

Improve event caching code #10119

Merged
merged 16 commits into from
Aug 4, 2021
Merged
1 change: 1 addition & 0 deletions changelog.d/10119.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve event caching mechanism to avoid having multiple copies of an event in memory at a time.
144 changes: 105 additions & 39 deletions synapse/storage/databases/main/events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import logging
import threading
from collections import namedtuple
from typing import (
Collection,
Container,
Expand All @@ -27,6 +26,7 @@
overload,
)

import attr
from constantly import NamedConstant, Names
from typing_extensions import Literal

Expand All @@ -42,7 +42,11 @@
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.logging.context import (
PreserveLoggingContext,
current_context,
make_deferred_yieldable,
)
from synapse.metrics.background_process_metrics import (
run_as_background_process,
wrap_as_background_process,
Expand All @@ -56,6 +60,8 @@
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.storage.util.sequence import build_sequence_generator
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.lrucache import LruCache
from synapse.util.iterutils import batch_iter
Expand All @@ -74,7 +80,10 @@
EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events


_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
@attr.s(slots=True, auto_attribs=True)
class _EventCacheEntry:
    """A single entry in the event cache: an event plus its redacted form, if any."""

    # The event itself.
    event: EventBase
    # The redacted version of the event; may be None (presumably when the
    # event has not been redacted -- confirm against the code that builds
    # these entries).
    redacted_event: Optional[EventBase]


class EventRedactBehaviour(Names):
Expand Down Expand Up @@ -161,6 +170,13 @@ def __init__(self, database: DatabasePool, db_conn, hs):
max_size=hs.config.caches.event_cache_size,
)

# Map from event ID to a deferred that will result in a map from event
# ID to cache entry. Note that the returned dict may not have the
# requested event in it if the event isn't in the DB.
self._current_event_fetches: Dict[
str, ObservableDeferred[Dict[str, _EventCacheEntry]]
] = {}

self._event_fetch_lock = threading.Condition()
self._event_fetch_list = []
self._event_fetch_ongoing = 0
Expand Down Expand Up @@ -476,7 +492,9 @@ async def get_events_as_list(

return events

async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
async def _get_events_from_cache_or_db(
self, event_ids: Iterable[str], allow_rejected: bool = False
) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the cache or the database.

If events are pulled from the database, they will be cached for future lookups.
Expand All @@ -485,53 +503,107 @@ async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):

Args:

event_ids (Iterable[str]): The event_ids of the events to fetch
event_ids: The event_ids of the events to fetch

allow_rejected (bool): Whether to include rejected events. If False,
allow_rejected: Whether to include rejected events. If False,
rejected events are omitted from the response.

Returns:
Dict[str, _EventCacheEntry]:
map from event id to result
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
event_ids, allow_rejected=allow_rejected
event_ids,
)

missing_events_ids = [e for e in event_ids if e not in event_entry_map]
missing_events_ids = {e for e in event_ids if e not in event_entry_map}

# We now look up if we're already fetching some of the events in the DB,
# if so we wait for those lookups to finish instead of pulling the same
# events out of the DB multiple times.
already_fetching: Dict[str, defer.Deferred] = {}

for event_id in missing_events_ids:
deferred = self._current_event_fetches.get(event_id)
if deferred is not None:
# We're already pulling the event out of the DB. Add the deferred
# to the collection of deferreds to wait on.
already_fetching[event_id] = deferred.observe()

missing_events_ids.difference_update(already_fetching)

if missing_events_ids:
log_ctx = current_context()
log_ctx.record_event_fetch(len(missing_events_ids))

# Add entries to `self._current_event_fetches` for each event we're
# going to pull from the DB. We use a single deferred that resolves
# to all the events we pulled from the DB (this will result in this
# function returning more events than requested, but that can happen
# already due to `_get_events_from_db`).
fetching_deferred: ObservableDeferred[
Dict[str, _EventCacheEntry]
] = ObservableDeferred(defer.Deferred())
for event_id in missing_events_ids:
self._current_event_fetches[event_id] = fetching_deferred

# Note that _get_events_from_db is also responsible for turning db rows
# into FrozenEvents (via _get_event_from_row), which involves seeing if
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
missing_events = await self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
try:
missing_events = await self._get_events_from_db(
missing_events_ids,
)

event_entry_map.update(missing_events)
event_entry_map.update(missing_events)
except Exception as e:
with PreserveLoggingContext():
fetching_deferred.errback(e)
raise e
finally:
# Ensure that we mark these events as no longer being fetched.
for event_id in missing_events_ids:
self._current_event_fetches.pop(event_id, None)
Comment on lines +565 to +567
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if doing this after fetching_deferred.errback could cause races. I can't really think how it could, though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the only thing that can happen is that another request waits on the deferred after it's already been resolved. That shouldn't be an issue, as awaiting the deferred will just return immediately?


with PreserveLoggingContext():
fetching_deferred.callback(missing_events)

if already_fetching:
# Wait for the other event requests to finish and add their results
# to ours.
results = await make_deferred_yieldable(
defer.gatherResults(
already_fetching.values(),
consumeErrors=True,
)
).addErrback(unwrapFirstError)

for result in results:
event_entry_map.update(result)

if not allow_rejected:
event_entry_map = {
event_id: entry
for event_id, entry in event_entry_map.items()
if not entry.event.rejected_reason
}

return event_entry_map

def _invalidate_get_event_cache(self, event_id):
self._get_event_cache.invalidate((event_id,))

def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
"""Fetch events from the caches
def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, _EventCacheEntry]:
"""Fetch events from the caches.

Args:
events (Iterable[str]): list of event_ids to fetch
allow_rejected (bool): Whether to return events that were rejected
update_metrics (bool): Whether to update the cache hit ratio metrics
May return rejected events.

Returns:
dict of event_id -> _EventCacheEntry for each event_id in cache. If
allow_rejected is `False` then there will still be an entry but it
will be `None`
Args:
events: list of event_ids to fetch
update_metrics: Whether to update the cache hit ratio metrics
"""
event_map = {}

Expand All @@ -542,10 +614,7 @@ def _get_events_from_cache(self, events, allow_rejected, update_metrics=True):
if not ret:
continue

if allow_rejected or not ret.event.rejected_reason:
event_map[event_id] = ret
else:
event_map[event_id] = None
event_map[event_id] = ret

return event_map

Expand Down Expand Up @@ -672,23 +741,23 @@ def fire(evs, exc):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)

async def _get_events_from_db(self, event_ids, allow_rejected=False):
async def _get_events_from_db(
self, event_ids: Iterable[str]
) -> Dict[str, _EventCacheEntry]:
"""Fetch a bunch of events from the database.

May return rejected events.

Returned events will be added to the cache for future lookups.

Unknown events are omitted from the response.

Args:
event_ids (Iterable[str]): The event_ids of the events to fetch

allow_rejected (bool): Whether to include rejected events. If False,
rejected events are omitted from the response.
event_ids: The event_ids of the events to fetch

Returns:
Dict[str, _EventCacheEntry]:
map from event id to result. May return extra events which
weren't asked for.
map from event id to result. May return extra events which
weren't asked for.
"""
fetched_events = {}
events_to_fetch = event_ids
Expand Down Expand Up @@ -717,9 +786,6 @@ async def _get_events_from_db(self, event_ids, allow_rejected=False):

rejected_reason = row["rejected_reason"]

if not allow_rejected and rejected_reason:
continue

# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
Expand Down
6 changes: 2 additions & 4 deletions synapse/storage/databases/main/roommember.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,14 +629,12 @@ async def _get_joined_users_from_context(
# We don't update the event cache hit ratio as it completely throws off
# the hit ratio counts. After all, we don't populate the cache if we
# miss it here
event_map = self._get_events_from_cache(
member_event_ids, allow_rejected=False, update_metrics=False
)
event_map = self._get_events_from_cache(member_event_ids, update_metrics=False)

missing_member_event_ids = []
for event_id in member_event_ids:
ev_entry = event_map.get(event_id)
if ev_entry:
if ev_entry and not ev_entry.event.rejected_reason:
if ev_entry.event.membership == Membership.JOIN:
users_in_room[ev_entry.event.state_key] = ProfileInfo(
display_name=ev_entry.event.content.get("displayname", None),
Expand Down
50 changes: 50 additions & 0 deletions tests/storage/databases/main/test_events_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@
import json

from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client.v1 import login, room
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.util.async_helpers import yieldable_gather_results

from tests import unittest

Expand Down Expand Up @@ -94,3 +97,50 @@ def test_query_via_event_cache(self):
res = self.get_success(self.store.have_seen_events("room1", ["event10"]))
self.assertEquals(res, {"event10"})
self.assertEquals(ctx.get_resource_usage().db_txn_count, 0)


class EventCacheTestCase(unittest.HomeserverTestCase):
    """Check that the layers of the event cache work together."""

    servlets = [
        admin.register_servlets,
        room.register_servlets,
        login.register_servlets,
    ]

    def prepare(self, reactor, clock, hs):
        """Create a user, a room and one event, then empty the event cache."""
        self.store: EventsWorkerStore = hs.get_datastore()

        # A logged-in user who owns a room containing a single sent event.
        self.user = self.register_user("user", "pass")
        self.token = self.login(self.user, "pass")
        self.room = self.helper.create_room_as(self.user, tok=self.token)
        self.event_id = self.helper.send(self.room, tok=self.token)["event_id"]

        # Clear the event cache so that every test starts with it empty.
        self.store._get_event_cache.clear()

    def test_simple(self):
        """Fetching an uncached event should record exactly one DB fetch."""
        with LoggingContext("test") as ctx:
            fetch = self.store.get_event(self.event_id)
            self.get_success(fetch)

            # The cache was emptied in `prepare`, so this lookup must have
            # gone to the DB.
            usage = ctx.get_resource_usage()
            self.assertEqual(usage.evt_db_fetch_count, 1)

    def test_dedupe(self):
        """Concurrent requests for the same event should share one DB fetch."""
        with LoggingContext("test") as ctx:
            requested_ids = [self.event_id, self.event_id]
            gathered = yieldable_gather_results(self.store.get_event, requested_ids)
            self.get_success(gathered)

            # Two requests were made, but the event should only have been
            # pulled out of the DB once.
            usage = ctx.get_resource_usage()
            self.assertEqual(usage.evt_db_fetch_count, 1)