diff --git a/changelog.d/8342.bugfix b/changelog.d/8342.bugfix
new file mode 100644
index 000000000000..786057facb44
--- /dev/null
+++ b/changelog.d/8342.bugfix
@@ -0,0 +1 @@
+Fix ratelimiting of federation `/send` requests.
diff --git a/changelog.d/8343.feature b/changelog.d/8343.feature
new file mode 100644
index 000000000000..ccecb22f37f4
--- /dev/null
+++ b/changelog.d/8343.feature
@@ -0,0 +1 @@
+Add flags to the `/versions` endpoint that include whether new rooms default to using E2EE.
diff --git a/changelog.d/8349.bugfix b/changelog.d/8349.bugfix
new file mode 100644
index 000000000000..cf2f531b1487
--- /dev/null
+++ b/changelog.d/8349.bugfix
@@ -0,0 +1 @@
+Fix a longstanding bug where back pagination over federation could get stuck if it failed to handle a received event.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 218df884b02a..ff00f0b3022e 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -97,10 +97,16 @@ def __init__(self, hs):
         self.state = hs.get_state_handler()
         self.device_handler = hs.get_device_handler()
 
+        self._federation_ratelimiter = hs.get_federation_ratelimiter()
         self._server_linearizer = Linearizer("fed_server")
         self._transaction_linearizer = Linearizer("fed_txn_handler")
 
+        # We cache results for transactions with the same ID
+        self._transaction_resp_cache = ResponseCache(
+            hs, "fed_txn_handler", timeout_ms=30000
+        )
+
         self.transaction_actions = TransactionActions(self.store)
 
         self.registry = hs.get_federation_registry()
@@ -135,22 +141,44 @@ async def on_incoming_transaction(
         request_time = self._clock.time_msec()
 
         transaction = Transaction(**transaction_data)
+        transaction_id = transaction.transaction_id  # type: ignore
 
-        if not transaction.transaction_id:  # type: ignore
+        if not transaction_id:
             raise Exception("Transaction missing transaction_id")
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)  # type: ignore
+        logger.debug("[%s] Got transaction", transaction_id)
 
-        # use a linearizer to ensure that we don't process the same transaction
-        # multiple times in parallel.
-        with (
-            await self._transaction_linearizer.queue(
-                (origin, transaction.transaction_id)  # type: ignore
-            )
-        ):
-            result = await self._handle_incoming_transaction(
-                origin, transaction, request_time
-            )
+        # We wrap in a ResponseCache so that we de-duplicate retried
+        # transactions.
+        return await self._transaction_resp_cache.wrap(
+            (origin, transaction_id),
+            self._on_incoming_transaction_inner,
+            origin,
+            transaction,
+            request_time,
+        )
+
+    async def _on_incoming_transaction_inner(
+        self, origin: str, transaction: Transaction, request_time: int
+    ) -> Tuple[int, Dict[str, Any]]:
+        # Use a linearizer to ensure that transactions from a remote are
+        # processed in order.
+        with await self._transaction_linearizer.queue(origin):
+            # We rate limit here *after* we've queued up the incoming requests,
+            # so that we don't fill up the ratelimiter with blocked requests.
+            #
+            # This is important as the ratelimiter allows N concurrent requests
+            # at a time, and only starts ratelimiting if there are more requests
+            # than that being processed at a time. If we queued up requests in
+            # the linearizer/response cache *after* the ratelimiting then those
+            # queued up requests would count as part of the allowed limit of N
+            # concurrent requests.
+            with self._federation_ratelimiter.ratelimit(origin) as d:
+                await d
+
+                result = await self._handle_incoming_transaction(
+                    origin, transaction, request_time
+                )
 
         return result
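Note on the new flow in federation_server.py: retried `/send` transactions are de-duplicated by the ResponseCache, transactions from one origin are serialised by the Linearizer, and only a request that is actually being processed occupies one of the ratelimiter's N slots. The sketch below is a minimal asyncio illustration of that ordering, not Synapse code; the names (TxnPipeline, max_concurrent, process) are made up, and unlike Synapse's ResponseCache it never expires completed entries.

import asyncio
from typing import Any, Awaitable, Callable, Dict, Tuple


class TxnPipeline:
    """Toy stand-in for ResponseCache + Linearizer + FederationRateLimiter."""

    def __init__(self, max_concurrent: int = 3):
        # Response cache: retries of an in-flight transaction share one result.
        self._pending: Dict[Tuple[str, str], asyncio.Future] = {}
        # Linearizer: transactions from the same origin are processed in order.
        self._locks: Dict[str, asyncio.Lock] = {}
        # Ratelimiter: at most N transactions actually being processed at once.
        self._limiter = asyncio.Semaphore(max_concurrent)

    async def handle(
        self, origin: str, txn_id: str, process: Callable[[], Awaitable[Any]]
    ) -> Any:
        key = (origin, txn_id)
        if key in self._pending:
            # Retried transaction: wait for the original attempt's result.
            return await asyncio.shield(self._pending[key])

        fut: asyncio.Future = asyncio.get_running_loop().create_future()
        self._pending[key] = fut
        try:
            lock = self._locks.setdefault(origin, asyncio.Lock())
            async with lock:
                # Rate limit only *after* queueing, so requests waiting on the
                # lock or the cache never occupy one of the N ratelimiter slots.
                async with self._limiter:
                    result = await process()
            fut.set_result(result)
            return result
        except Exception as exc:
            fut.set_exception(exc)
            raise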
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 794b3098e9ca..bbf111955669 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -46,7 +46,6 @@
 )
 from synapse.server import HomeServer
 from synapse.types import ThirdPartyInstanceID, get_domain_from_id
-from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.versionstring import get_version_string
 
 logger = logging.getLogger(__name__)
@@ -73,9 +72,7 @@ def __init__(self, hs, servlet_groups=None):
         super(TransportLayerServer, self).__init__(hs, canonical_json=False)
 
         self.authenticator = Authenticator(hs)
-        self.ratelimiter = FederationRateLimiter(
-            self.clock, config=hs.config.rc_federation
-        )
+        self.ratelimiter = hs.get_federation_ratelimiter()
 
         self.register_servlets()
 
@@ -273,6 +270,8 @@ class BaseFederationServlet:
     PREFIX = FEDERATION_V1_PREFIX  # Allows specifying the API version
 
+    RATELIMIT = True  # Whether to rate limit requests or not
+
     def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
@@ -336,7 +335,7 @@ async def new_func(request, *args, **kwargs):
             )
 
             with scope:
-                if origin:
+                if origin and self.RATELIMIT:
                     with ratelimiter.ratelimit(origin) as d:
                         await d
                         if request._disconnected:
@@ -373,6 +372,10 @@ def register(self, server):
 class FederationSendServlet(BaseFederationServlet):
     PATH = "/send/(?P<transaction_id>[^/]*)/?"
 
+    # We ratelimit manually in the handler as we queue up the requests and we
+    # don't want to fill up the ratelimiter with blocked requests.
+    RATELIMIT = False
+
     def __init__(self, handler, server_name, **kwargs):
         super(FederationSendServlet, self).__init__(
             handler, server_name=server_name, **kwargs
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 2831932edd48..b892a6f97e0d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -920,15 +920,26 @@ async def backfill(self, dest, room_id, limit, extremities):
 
         return events
 
-    async def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(
+        self, room_id: str, current_depth: int, limit: int
+    ) -> bool:
         """Checks the database to see if we should backfill before paginating,
         and if so do.
+
+        Args:
+            room_id
+            current_depth: The depth from which we're paginating. This is
+                used to decide if we should backfill and what extremities to
+                use.
+            limit: The number of events that the pagination request will
+                return. This is used as part of the heuristic to decide if we
+                should back paginate.
+        """
         extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
-            return
+            return False
 
         # We only want to paginate if we can actually see the events we'll get,
         # as otherwise we'll just spend a lot of resources to get redacted
@@ -981,16 +992,54 @@ async def maybe_backfill(self, room_id, current_depth):
         sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
         max_depth = sorted_extremeties_tuple[0][1]
 
+        # If we're approaching an extremity we trigger a backfill, otherwise we
+        # no-op.
+        #
+        # We chose twice the limit here so that clients paginating backwards
+        # will send pagination requests that trigger backfill at least twice
+        # using the most recent extremity before it gets removed (see below). We
+        # chose a multiple greater than one to allow for failures, but a much
+        # larger factor would trigger a backfill request much earlier than
+        # necessary.
+        if current_depth - 2 * limit > max_depth:
+            logger.debug(
+                "Not backfilling as we don't need to. %d < %d - 2 * %d",
+                max_depth,
+                current_depth,
+                limit,
+            )
+            return False
+
+        logger.debug(
+            "room_id: %s, backfill: current_depth: %s, max_depth: %s, extrems: %s",
+            room_id,
+            current_depth,
+            max_depth,
+            sorted_extremeties_tuple,
+        )
+
+        # We ignore extremities that have a greater depth than our current depth
+        # as:
+        #    1. we don't really care about getting events that have happened
+        #       before our current position; and
+        #    2. we have likely previously tried and failed to backfill from that
+        #       extremity, so to avoid getting "stuck" requesting the same
+        #       backfill repeatedly we drop those extremities.
+        filtered_sorted_extremeties_tuple = [
+            t for t in sorted_extremeties_tuple if int(t[1]) <= current_depth
+        ]
+
+        # However, we need to check that the filtered extremities are non-empty.
+        # If they are empty then we can either a) bail or b) still attempt to
+        # backfill. We opt to try backfilling anyway just in case we do get
+        # relevant events.
+        if filtered_sorted_extremeties_tuple:
+            sorted_extremeties_tuple = filtered_sorted_extremeties_tuple
+
         # We don't want to specify too many extremities as it causes the backfill
         # request URI to be too long.
         extremities = dict(sorted_extremeties_tuple[:5])
 
-        if current_depth > max_depth:
-            logger.debug(
-                "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
-            )
-            return
-
         # Now we need to decide which hosts to hit first.
 
         # First we try hosts that are already in the room
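To make the backfill heuristic above concrete, here is a toy illustration (plain integers, made-up function names, not the handler's actual code) of the 2 * limit trigger and the extremity filtering:

from typing import Dict


def should_backfill(current_depth: int, max_extremity_depth: int, limit: int) -> bool:
    # Mirrors the check above: only backfill once the pagination position is
    # within 2 * limit of the deepest backwards extremity we know about.
    return current_depth - 2 * limit <= max_extremity_depth


def pick_extremities(extremities: Dict[str, int], current_depth: int) -> Dict[str, int]:
    # Deepest first, preferring extremities at or before our current position,
    # but falling back to the unfiltered list rather than giving up entirely.
    sorted_extrems = sorted(extremities.items(), key=lambda e: -e[1])
    filtered = [t for t in sorted_extrems if t[1] <= current_depth]
    return dict((filtered or sorted_extrems)[:5])


# With a page size of 10 and the oldest known extremity at depth 100, a client
# paginating backwards starts triggering backfill once it reaches depth 120.
assert should_backfill(current_depth=130, max_extremity_depth=100, limit=10) is False
assert should_backfill(current_depth=120, max_extremity_depth=100, limit=10) is True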
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d929a68f7d8c..a0b3bdb5e0c3 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -358,9 +358,9 @@ async def get_messages(
                 # if we're going backwards, we might need to backfill. This
                 # requires that we have a topo token.
                 if room_token.topological:
-                    max_topo = room_token.topological
+                    curr_topo = room_token.topological
                 else:
-                    max_topo = await self.store.get_max_topological_token(
+                    curr_topo = await self.store.get_current_topological_token(
                         room_id, room_token.stream
                     )
 
@@ -379,13 +379,13 @@ async def get_messages(
                         leave_token = RoomStreamToken.parse(leave_token_str)
                         assert leave_token.topological is not None
 
-                        if leave_token.topological < max_topo:
+                        if leave_token.topological < curr_topo:
                             from_token = from_token.copy_and_replace(
                                 "room_key", leave_token
                             )
 
                 await self.hs.get_handlers().federation_handler.maybe_backfill(
-                    room_id, max_topo
+                    room_id, curr_topo, limit=pagin_config.limit,
                 )
 
             to_room_key = None
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 58ec5a694aca..24a58af0617b 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -19,6 +19,7 @@
 import logging
 import re
 
+from synapse.api.constants import RoomCreationPreset
 from synapse.http.servlet import RestServlet
 
 logger = logging.getLogger(__name__)
@@ -31,6 +32,20 @@ def __init__(self, hs):
         super(VersionsRestServlet, self).__init__()
         self.config = hs.config
 
+        # Calculate these once since they shouldn't change after start-up.
+        self.e2ee_forced_public = (
+            RoomCreationPreset.PUBLIC_CHAT
+            in self.config.encryption_enabled_by_default_for_room_presets
+        )
+        self.e2ee_forced_private = (
+            RoomCreationPreset.PRIVATE_CHAT
+            in self.config.encryption_enabled_by_default_for_room_presets
+        )
+        self.e2ee_forced_trusted_private = (
+            RoomCreationPreset.TRUSTED_PRIVATE_CHAT
+            in self.config.encryption_enabled_by_default_for_room_presets
+        )
+
     def on_GET(self, request):
         return (
             200,
@@ -65,6 +80,10 @@ def on_GET(self, request):
                     "m.lazy_load_members": True,
                     # Implements additional endpoints as described in MSC2666
                     "uk.half-shot.msc2666": True,
+                    # Whether new rooms will be set to encrypted or not (based on presets).
+                    "io.element.e2ee_forced.public": self.e2ee_forced_public,
+                    "io.element.e2ee_forced.private": self.e2ee_forced_private,
+                    "io.element.e2ee_forced.trusted_private": self.e2ee_forced_trusted_private,
                 },
             },
         )
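For context, clients see the new flags in the `unstable_features` section of the `/versions` response. A possible client-side check (hypothetical homeserver URL; assumes the `requests` package is installed) could look like:

import requests

resp = requests.get("https://matrix.example.org/_matrix/client/versions")
features = resp.json().get("unstable_features", {})

if features.get("io.element.e2ee_forced.private"):
    print("New private-chat rooms on this server default to end-to-end encryption")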
diff --git a/synapse/server.py b/synapse/server.py
index 9055b97ac317..5e3752c3334f 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -114,6 +114,7 @@
 from synapse.types import DomainSpecificString
 from synapse.util import Clock
 from synapse.util.distributor import Distributor
+from synapse.util.ratelimitutils import FederationRateLimiter
 from synapse.util.stringutils import random_string
 
 logger = logging.getLogger(__name__)
@@ -642,6 +643,10 @@ def get_replication_data_handler(self) -> ReplicationDataHandler:
     def get_replication_streams(self) -> Dict[str, Stream]:
         return {stream.NAME: stream(self) for stream in STREAMS_MAP.values()}
 
+    @cache_in_self
+    def get_federation_ratelimiter(self) -> FederationRateLimiter:
+        return FederationRateLimiter(self.clock, config=self.config.rc_federation)
+
     async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
         return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 7dbe11513b3c..6933d05865a4 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -640,23 +640,20 @@ async def get_topological_token_for_event(self, event_id: str) -> str:
         )
         return "t%d-%d" % (row["topological_ordering"], row["stream_ordering"])
 
-    async def get_max_topological_token(self, room_id: str, stream_key: int) -> int:
-        """Get the max topological token in a room before the given stream
+    async def get_current_topological_token(self, room_id: str, stream_key: int) -> int:
+        """Gets the topological token in a room after or at the given stream
         ordering.
 
         Args:
             room_id
             stream_key
-
-        Returns:
-            The maximum topological token.
         """
         sql = (
-            "SELECT coalesce(max(topological_ordering), 0) FROM events"
-            " WHERE room_id = ? AND stream_ordering < ?"
+            "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
+            " WHERE room_id = ? AND stream_ordering >= ?"
         )
         row = await self.db_pool.execute(
-            "get_max_topological_token", None, sql, room_id, stream_key
+            "get_current_topological_token", None, sql, room_id, stream_key
        )
         return row[0][0] if row else 0
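The storage change above flips the query's meaning: the old `get_max_topological_token` took the greatest topological ordering strictly before the stream position, while `get_current_topological_token` takes the smallest one at or after it. A self-contained sqlite3 sketch with a toy schema (not Synapse's real events table) shows the difference:

import sqlite3

db = sqlite3.connect(":memory:")
db.execute(
    "CREATE TABLE events (room_id TEXT, topological_ordering INTEGER, stream_ordering INTEGER)"
)
db.executemany(
    "INSERT INTO events VALUES (?, ?, ?)",
    [("!room:example.org", 1, 10), ("!room:example.org", 2, 20), ("!room:example.org", 3, 30)],
)

args = ("!room:example.org", 20)
old = db.execute(
    "SELECT coalesce(max(topological_ordering), 0) FROM events"
    " WHERE room_id = ? AND stream_ordering < ?",
    args,
).fetchone()[0]
new = db.execute(
    "SELECT coalesce(MIN(topological_ordering), 0) FROM events"
    " WHERE room_id = ? AND stream_ordering >= ?",
    args,
).fetchone()[0]

print(old, new)  # 1 2 -- the new query lands on the event at the given stream position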