Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Faster joins: Support for calling /federation/v1/state #12013

Merged
merged 12 commits into from
Feb 22, 2022
1 change: 1 addition & 0 deletions changelog.d/12013.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.
10 changes: 9 additions & 1 deletion synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ async def _check_sigs_and_hash(
) -> EventBase:
"""Checks that event is correctly signed by the sending server.
Also checks the content hash, and redacts the event if there is a mismatch.
Also runs the event through the spam checker; if it fails, redacts the event
and flags it as soft-failed.
Args:
room_version: The room version of the PDU
pdu: the event to be checked
Expand All @@ -55,7 +60,10 @@ async def _check_sigs_and_hash(
* the original event if the checks pass
* a redacted version of the event (if the signature
matched but the hash did not)
* throws a SynapseError if the signature check failed."""
Raises:
SynapseError if the signature check failed.
"""
try:
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
except SynapseError as e:
Expand Down
93 changes: 81 additions & 12 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,26 +413,90 @@ async def get_room_state_ids(

return state_event_ids, auth_event_ids

async def get_room_state(
self,
destination: str,
room_id: str,
event_id: str,
room_version: RoomVersion,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Calls the /state endpoint to fetch the state at a particular point
in the room.

Any invalid events (those with incorrect or unverifiable signatures or hashes)
are filtered out from the response, and any duplicate events are removed.

(Size limits and other event-format checks are *not* performed.)

Note that the result is not ordered, so callers must be careful to process
the events in an order that handles dependencies.

Returns:
a tuple of (state events, auth events)
"""
result = await self.transport_layer.get_room_state(
room_version,
destination,
room_id,
event_id,
)
state_events = result.state
auth_events = result.auth_events

# we may as well filter out any duplicates from the response, to save
# processing them multiple times. (In particular, events may be present in
# `auth_events` as well as `state`, which is redundant).
#
# We don't rely on the sort order of the events, so we can just stick them
# in a dict.
state_event_map = {event.event_id: event for event in state_events}
auth_event_map = {
event.event_id: event
for event in auth_events
if event.event_id not in state_event_map
}

logger.info(
"Processing from /state: %d state events, %d auth events",
len(state_event_map),
len(auth_event_map),
)

valid_auth_events = await self._check_sigs_and_hash_and_fetch(
destination, auth_event_map.values(), room_version
)

valid_state_events = await self._check_sigs_and_hash_and_fetch(
destination, state_event_map.values(), room_version
)

return valid_state_events, valid_auth_events

async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
) -> List[EventBase]:
"""Takes a list of PDUs and checks the signatures and hashes of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
that PDU.
"""Checks the signatures and hashes of a list of events.

If a PDU fails its signature check then we check if we have it in
the database, and if not then request it from the sender's server (if that
is different from `origin`). If that still fails, the event is omitted from
the returned list.

If a PDU fails its content hash check then it is redacted.

The given list of PDUs are not modified, instead the function returns
Also runs each event through the spam checker; if it fails, redacts the event
and flags it as soft-failed.

The given list of PDUs are not modified; instead the function returns
a new list.

Args:
origin
pdu
room_version
origin: The server that sent us these events
pdus: The events to be checked
room_version: the version of the room these events are in

Returns:
A list of PDUs that have valid signatures and hashes.
Expand Down Expand Up @@ -463,11 +527,16 @@ async def _check_sigs_and_hash_and_fetch_one(
origin: str,
room_version: RoomVersion,
) -> Optional[EventBase]:
"""Takes a PDU and checks its signatures and hashes. If the PDU fails
its signature check then we check if we have it in the database and if
not then request if from the originating server of that PDU.
"""Takes a PDU and checks its signatures and hashes.

If the PDU fails its signature check then we check if we have it in the
database; if not, we then request it from sender's server (if that is not the
same as `origin`). If that still fails, we return None.

If the PDU fails its content hash check, it is redacted.

If then PDU fails its content hash check then it is redacted.
Also runs the event through the spam checker; if it fails, redacts the event
and flags it as soft-failed.

Args:
origin
Expand Down
70 changes: 67 additions & 3 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,12 @@ def __init__(self, hs):
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
"""Requests all state for a given room from the given server at the
given event. Returns the state's event_id's
"""Requests the IDs of all state for a given room at the given event.

Args:
destination: The host name of the remote homeserver we want
to get the state from.
context: The name of the context we want the state of
room_id: the room we want the state of
event_id: The event we want the context at.

Returns:
Expand All @@ -86,6 +85,29 @@ async def get_room_state_ids(
try_trailing_slash_on_400=True,
)

async def get_room_state(
self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> "StateRequestResponse":
"""Requests the full state for a given room at the given event.

Args:
room_version: the version of the room (required to build the event objects)
destination: The host name of the remote homeserver we want
to get the state from.
room_id: the room we want the state of
event_id: The event we want the context at.

Returns:
Results in a dict received from the remote homeserver.
"""
path = _create_v1_path("/state/%s", room_id)
return await self.client.get_json(
destination,
path=path,
args={"event_id": event_id},
parser=_StateParser(room_version),
)

async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
Expand Down Expand Up @@ -1272,6 +1294,14 @@ class SendJoinResponse:
event: Optional[EventBase] = None


@attr.s(slots=True, auto_attribs=True)
class StateRequestResponse:
"""The parsed response of a `/state` request."""

auth_events: List[EventBase]
state: List[EventBase]


@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
Expand Down Expand Up @@ -1355,3 +1385,37 @@ def finish(self) -> SendJoinResponse:
self._response.event_dict, self._room_version
)
return self._response


class _StateParser(ByteParser[StateRequestResponse]):
"""A parser for the response to `/state` requests.

Args:
room_version: The version of the room.
"""

CONTENT_TYPE = "application/json"

def __init__(self, room_version: RoomVersion):
self._response = StateRequestResponse([], [])
self._room_version = room_version
self._coros = [
ijson.items_coro(
_event_list_parser(room_version, self._response.state),
"pdus.item",
use_float=True,
),
ijson.items_coro(
_event_list_parser(room_version, self._response.auth_events),
"auth_chain.item",
use_float=True,
),
]

def write(self, data: bytes) -> int:
for c in self._coros:
c.send(data)
return len(data)

def finish(self) -> StateRequestResponse:
return self._response
50 changes: 49 additions & 1 deletion synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,6 +958,7 @@ async def post_json(
)
return body

@overload
async def get_json(
self,
destination: str,
Expand All @@ -967,7 +968,38 @@ async def get_json(
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
...

@overload
async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = ...,
retry_on_dns_fail: bool = ...,
timeout: Optional[int] = ...,
ignore_backoff: bool = ...,
try_trailing_slash_on_400: bool = ...,
parser: ByteParser[T] = ...,
max_response_size: Optional[int] = ...,
) -> T:
...

async def get_json(
self,
destination: str,
path: str,
args: Optional[QueryArgs] = None,
retry_on_dns_fail: bool = True,
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Optional[ByteParser] = None,
max_response_size: Optional[int] = None,
):
"""GETs some json from the given host homeserver and path

Args:
Expand All @@ -992,6 +1024,13 @@ async def get_json(
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.

parser: The parser to use to decode the response. Defaults to
parsing as JSON.

max_response_size: The maximum size to read from the response. If None,
uses the default.

Returns:
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
Expand Down Expand Up @@ -1026,8 +1065,17 @@ async def get_json(
else:
_sec_timeout = self.default_timeout

if parser is None:
parser = JsonParser()

body = await _handle_response(
self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
self.reactor,
_sec_timeout,
request,
response,
start_ms,
parser=parser,
max_response_size=max_response_size,
)

return body
Expand Down
Loading