diff --git a/synapse/replication/http/state.py b/synapse/replication/http/state.py index 838b7584e56f..7e35b9cd8a64 100644 --- a/synapse/replication/http/state.py +++ b/synapse/replication/http/state.py @@ -19,6 +19,7 @@ from synapse.api.errors import SynapseError from synapse.http.server import HttpServer +from synapse.http.servlet import parse_json_object_from_request from synapse.replication.http._base import ReplicationEndpoint from synapse.types import JsonDict @@ -51,10 +52,15 @@ def __init__(self, hs: "HomeServer"): self._state_handler = hs.get_state_handler() self._events_shard_config = hs.config.worker.events_shard_config self._instance_name = hs.get_instance_name() + self._main_store = hs.get_datastores().main + self._replication = hs.get_replication_data_handler() @staticmethod - async def _serialize_payload(room_id: str) -> JsonDict: # type: ignore[override] - return {} + async def _serialize_payload(room_id: str, local_instance_name: str, unpartial_state_events_position: int) -> JsonDict: # type: ignore[override] + return { + "instance_name": local_instance_name, + "unpartial_state_events_position": unpartial_state_events_position, + } async def _handle_request( # type: ignore[override] self, request: Request, room_id: str @@ -65,9 +71,20 @@ async def _handle_request( # type: ignore[override] 400, "/update_current_state request was routed to the wrong worker" ) + payload = parse_json_object_from_request(request) + await self._replication.wait_for_stream_position( + payload["instance_name"], + "un_partial_stated_event", + payload["unpartial_state_events_position"], + ) + await self._state_handler.update_current_state(room_id) - return 200, {} + return 200, { + "caches_position": self._main_store._cache_id_gen.get_current_token_for_writer( + self._instance_name + ) + } def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index fdfb46ab82ad..0d092e8113a5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -566,10 +566,21 @@ async def update_current_state(self, room_id: str) -> None: """ writer_instance = self._events_shard_config.get_instance(room_id) if writer_instance != self._instance_name: - await self._update_current_state_client( + res = await self._update_current_state_client( instance_name=writer_instance, room_id=room_id, + local_instance_name=self._instance_name, + unpartial_state_events_position=self.store._un_partial_stated_events_stream_id_gen.get_current_token_for_writer( + self._instance_name + ), ) + + await self.hs.get_replication_data_handler().wait_for_stream_position( + writer_instance, + "caches", + res["caches_position"], + ) + return assert self._storage_controllers.persistence is not None