diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index c751edac70e7..cc96a33b04d2 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -52,7 +52,7 @@
     StateMap,
     get_domain_from_id,
 )
-from synapse.util.async_helpers import ObservableDeferred
+from synapse.util.async_helpers import ObservableDeferred, yieldable_gather_results
 from synapse.util.metrics import Measure
 
 logger = logging.getLogger(__name__)
@@ -135,25 +135,24 @@ def __init__(
         self._currently_persisting_rooms: Set[str] = set()
         self._per_item_callback = per_item_callback
 
-    def add_to_queue(self, room_id, events_and_contexts, backfilled) -> Deferred:
+    async def add_to_queue(
+        self,
+        room_id: str,
+        events_and_contexts: Iterable[Tuple[EventBase, EventContext]],
+        backfilled: bool,
+    ) -> _PersistResult:
         """Add events to the queue, with the given persist_event options.
 
         If we are not already processing events in this room, starts off a
         background process to to so, calling the per_item_callback for each
         item.
 
-        NB: due to the normal usage pattern of this method, it does *not*
-        follow the synapse logcontext rules, and leaves the logcontext in
-        place whether or not the returned deferred is ready.
-
         Args:
             room_id (str):
             events_and_contexts (list[(EventBase, EventContext)]):
             backfilled (bool):
 
         Returns:
-            defer.Deferred: a deferred which will resolve once the events are
-            persisted. Runs its callbacks in the sentinel logcontext. The result
-            is the same as that returned by the `_per_item_callback` passed to
+            the result returned by the `_per_item_callback` passed to
             `__init__`.
         """
         queue = self._event_persist_queues.setdefault(room_id, deque())
@@ -175,7 +174,7 @@ def add_to_queue(self, room_id, events_and_contexts, backfilled) -> Deferred:
             end_item.events_and_contexts.extend(events_and_contexts)
 
         self._handle_queue(room_id)
-        return end_item.deferred.observe()
+        return await make_deferred_yieldable(end_item.deferred.observe())
 
     def _handle_queue(self, room_id):
         """Attempts to handle the queue for a room if not already being handled.
@@ -278,22 +277,20 @@ async def persist_events(
         for event, ctx in events_and_contexts:
             partitioned.setdefault(event.room_id, []).append((event, ctx))
 
-        deferreds = []
-        for room_id, evs_ctxs in partitioned.items():
-            d = self._event_persist_queue.add_to_queue(
+        async def enqueue(item):
+            room_id, evs_ctxs = item
+            return await self._event_persist_queue.add_to_queue(
                 room_id, evs_ctxs, backfilled=backfilled
             )
-            deferreds.append(d)
 
-        # Each deferred returns a map from event ID to existing event ID if the
-        # event was deduplicated. (The dict may also include other entries if
+        ret_vals = await yieldable_gather_results(enqueue, partitioned.items())
+
+        # Each call to add_to_queue returns a map from event ID to existing event ID if
+        # the event was deduplicated. (The dict may also include other entries if
         # the event was persisted in a batch with other events).
         #
-        # Since we use `defer.gatherResults` we need to merge the returned list
+        # Since we use `yieldable_gather_results` we need to merge the returned list
         # of dicts into one.
-        ret_vals = await make_deferred_yieldable(
-            defer.gatherResults(deferreds, consumeErrors=True)
-        )
         replaced_events: Dict[str, str] = {}
         for d in ret_vals:
             replaced_events.update(d)
@@ -321,14 +318,12 @@ async def persist_event(
             event if it was deduplicated due to an existing event matching the
             transaction ID.
""" - deferred = self._event_persist_queue.add_to_queue( - event.room_id, [(event, context)], backfilled=backfilled - ) - - # The deferred returns a map from event ID to existing event ID if the + # add_to_queue returns a map from event ID to existing event ID if the # event was deduplicated. (The dict may also include other entries if # the event was persisted in a batch with other events.) - replaced_events = await make_deferred_yieldable(deferred) + replaced_events = await self._event_persist_queue.add_to_queue( + event.room_id, [(event, context)], backfilled=backfilled + ) replaced_event = replaced_events.get(event.event_id) if replaced_event: event = await self.main_store.get_event(replaced_event)