Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Factor out an is_mine_server_name method #15542

Merged
merged 6 commits into from
May 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/15542.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Factor out an `is_mine_server_name` method.
4 changes: 2 additions & 2 deletions synapse/api/auth_blocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, hs: "HomeServer"):
self._mau_limits_reserved_threepids = (
hs.config.server.mau_limits_reserved_threepids
)
self._server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._track_appservice_user_ips = hs.config.appservice.track_appservice_user_ips

async def check_auth_blocking(
Expand Down Expand Up @@ -77,7 +77,7 @@ async def check_auth_blocking(
if requester:
if requester.authenticated_entity.startswith("@"):
user_id = requester.authenticated_entity
elif requester.authenticated_entity == self._server_name:
elif self._is_mine_server_name(requester.authenticated_entity):
# We never block the server from doing actions on behalf of
# users.
return
Expand Down
4 changes: 2 additions & 2 deletions synapse/crypto/keyring.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ def __init__(
process_batch_callback=self._inner_fetch_key_requests,
)

self._hostname = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name

# build a FetchKeyResult for each of our own keys, to shortcircuit the
# fetcher.
Expand Down Expand Up @@ -277,7 +277,7 @@ async def process_request(self, verify_request: VerifyJsonRequest) -> None:

# If we are the originating server, short-circuit the key-fetch for any keys
# we already have
if verify_request.server_name == self._hostname:
if self._is_mine_server_name(verify_request.server_name):
for key_id in verify_request.key_ids:
if key_id in self._local_verify_keys:
found_keys[key_id] = self._local_verify_keys[key_id]
Expand Down
2 changes: 1 addition & 1 deletion synapse/federation/federation_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ class FederationBase:
def __init__(self, hs: "HomeServer"):
self.hs = hs

self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.keyring = hs.get_keyring()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.store = hs.get_datastores().main
Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,7 +854,7 @@ async def _try_destination_list(

for destination in destinations:
# We don't want to ask our own server for information we don't have
if destination == self.server_name:
if self._is_mine_server_name(destination):
continue

try:
Expand Down Expand Up @@ -1536,7 +1536,7 @@ async def forward_third_party_invite(
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations:
if destination == self.server_name:
if self._is_mine_server_name(destination):
continue

try:
Expand Down
3 changes: 2 additions & 1 deletion synapse/federation/federation_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class FederationServer(FederationBase):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)

self.server_name = hs.hostname
self.handler = hs.get_federation_handler()
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self._federation_event_handler = hs.get_federation_event_handler()
Expand Down Expand Up @@ -942,7 +943,7 @@ async def _on_send_membership_event(
authorising_server = get_domain_from_id(
event.content[EventContentFields.AUTHORISING_USER]
)
if authorising_server != self.server_name:
if not self._is_mine_server_name(authorising_server):
raise SynapseError(
400,
f"Cannot authorise request from resident server: {authorising_server}",
Expand Down
3 changes: 2 additions & 1 deletion synapse/federation/send_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name

# We may have multiple federation sender instances, so we need to track
# their positions separately.
Expand Down Expand Up @@ -198,7 +199,7 @@ def build_and_send_edu(
key: Optional[Hashable] = None,
) -> None:
"""As per FederationSender"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves")
return

Expand Down
11 changes: 6 additions & 5 deletions synapse/federation/sender/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,7 @@ def __init__(self, hs: "HomeServer"):

self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name

self._presence_router: Optional["PresenceRouter"] = None
self._transaction_manager = TransactionManager(hs)
Expand Down Expand Up @@ -766,7 +767,7 @@ async def send_read_receipt(self, receipt: ReadReceipt) -> None:
domains = [
d
for d in domains_set
if d != self.server_name
if not self.is_mine_server_name(d)
and self._federation_shard_config.should_handle(self._instance_name, d)
]
if not domains:
Expand Down Expand Up @@ -832,7 +833,7 @@ def send_presence_to_destinations(
assert self.is_mine_id(state.user_id)

for destination in destinations:
if destination == self.server_name:
if self.is_mine_server_name(destination):
continue
if not self._federation_shard_config.should_handle(
self._instance_name, destination
Expand Down Expand Up @@ -860,7 +861,7 @@ def build_and_send_edu(
content: content of EDU
key: clobbering key for this edu
"""
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.info("Not sending EDU to ourselves")
return

Expand Down Expand Up @@ -897,7 +898,7 @@ def send_edu(self, edu: Edu, key: Optional[Hashable]) -> None:
queue.send_edu(edu)

def send_device_messages(self, destination: str, immediate: bool = True) -> None:
if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.warning("Not sending device update to ourselves")
return

Expand All @@ -919,7 +920,7 @@ def wake_destination(self, destination: str) -> None:
might have come back.
"""

if destination == self.server_name:
if self.is_mine_server_name(destination):
logger.warning("Not waking up ourselves")
return

Expand Down
4 changes: 2 additions & 2 deletions synapse/federation/transport/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@ class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""

def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
self._is_mine_server_name = hs.is_mine_server_name

async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
Expand Down Expand Up @@ -235,7 +235,7 @@ async def send_transaction(
transaction.transaction_id,
)

if transaction.destination == self.server_name:
if self._is_mine_server_name(transaction.destination):
raise RuntimeError("Transport layer cannot send to itself!")

# FIXME: This is only used by the tests. The actual json sent is
Expand Down
5 changes: 4 additions & 1 deletion synapse/federation/transport/server/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.store = hs.get_datastores().main
self.federation_domain_whitelist = (
hs.config.federation.federation_domain_whitelist
Expand Down Expand Up @@ -100,7 +101,9 @@ async def authenticate_request(
json_request["signatures"].setdefault(origin, {})[key] = sig

# if the origin_server sent a destination along it needs to match our own server_name
if destination is not None and destination != self.server_name:
if destination is not None and not self._is_mine_server_name(
destination
):
raise AuthenticationError(
HTTPStatus.UNAUTHORIZED,
"Destination mismatch in auth header",
Expand Down
5 changes: 3 additions & 2 deletions synapse/handlers/event_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
)
from synapse.events import EventBase
from synapse.events.builder import EventBuilder
from synapse.types import StateMap, StrCollection, get_domain_from_id
from synapse.types import StateMap, StrCollection

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -47,6 +47,7 @@ def __init__(self, hs: "HomeServer"):
self._store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state
self._server_name = hs.hostname
self._is_mine_id = hs.is_mine_id

async def check_auth_rules_from_context(
self,
Expand Down Expand Up @@ -247,7 +248,7 @@ async def check_restricted_join_rules(
if not await self.is_user_in_rooms(allowed_rooms, user_id):
# If this is a remote request, the user might be in an allowed room
# that we do not know about.
if get_domain_from_id(user_id) != self._server_name:
if not self._is_mine_id(user_id):
for room_id in allowed_rooms:
if not await self._store.is_host_joined(room_id, self._server_name):
raise SynapseError(
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/federation.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.keyring = hs.get_keyring()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name
self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
self.event_creation_handler = hs.get_event_creation_handler()
self.event_builder_factory = hs.get_event_builder_factory()
Expand Down Expand Up @@ -453,7 +454,7 @@ async def try_backfill(domains: StrCollection) -> bool:

for dom in domains:
# We don't want to ask our own server for information we don't have
if dom == self.server_name:
if self.is_mine_server_name(dom):
continue

try:
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/federation_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(self, hs: "HomeServer"):
self._notifier = hs.get_notifier()

self._is_mine_id = hs.is_mine_id
self._is_mine_server_name = hs.is_mine_server_name
self._server_name = hs.hostname
self._instance_name = hs.get_instance_name()

Expand Down Expand Up @@ -688,7 +689,7 @@ async def backfill(
server from invalid events (there is probably no point in trying to
re-fetch invalid events from every other HS in the room.)
"""
if dest == self._server_name:
if self._is_mine_server_name(dest):
raise SynapseError(400, "Can't backfill from self.")

events = await self._federation_client.backfill(
Expand Down
4 changes: 2 additions & 2 deletions synapse/handlers/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(self, hs: "HomeServer"):
self.max_avatar_size = hs.config.server.max_avatar_size
self.allowed_avatar_mimetypes = hs.config.server.allowed_avatar_mimetypes

self.server_name = hs.config.server.server_name
self._is_mine_server_name = hs.is_mine_server_name

self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules

Expand Down Expand Up @@ -309,7 +309,7 @@ async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
else:
server_name = host

if server_name == self.server_name:
if self._is_mine_server_name(server_name):
media_info = await self.store.get_local_media(media_id)
else:
media_info = await self.store.get_cached_remote_media(server_name, media_id)
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/sso.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def __init__(self, hs: "HomeServer"):
self._clock = hs.get_clock()
self._store = hs.get_datastores().main
self._server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self._registration_handler = hs.get_registration_handler()
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
Expand Down Expand Up @@ -802,7 +803,7 @@ def is_allowed_mime_type(content_type: str) -> bool:
if profile["avatar_url"] is not None:
server_name = profile["avatar_url"].split("/")[-2]
media_id = profile["avatar_url"].split("/")[-1]
if server_name == self._server_name:
if self._is_mine_server_name(server_name):
media = await self._media_repo.store.get_local_media(media_id)
if media is not None and upload_name == media["upload_name"]:
logger.info("skipping saving the user avatar")
Expand Down
3 changes: 2 additions & 1 deletion synapse/handlers/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(self, hs: "HomeServer"):
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
self.is_mine_server_name = hs.is_mine_server_name

self.federation = None
if hs.should_send_federation():
Expand Down Expand Up @@ -153,7 +154,7 @@ async def _push_remote(self, member: RoomMember, typing: bool) -> None:
member.room_id
)
for domain in hosts:
if domain != self.server_name:
if not self.is_mine_server_name(domain):
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
destination=domain,
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/admin/media.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,15 +258,15 @@ class DeleteMediaByID(RestServlet):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.auth = hs.get_auth()
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name
self.media_repository = hs.get_media_repository()

async def on_DELETE(
self, request: SynapseRequest, server_name: str, media_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)

if self.server_name != server_name:
if not self._is_mine_server_name(server_name):
raise SynapseError(HTTPStatus.BAD_REQUEST, "Can only delete local media")

if await self.store.get_local_media(media_id) is None:
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/client/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
limit = None

handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server.server_name:
if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
Expand Down Expand Up @@ -551,7 +551,7 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
limit = None

handler = self.hs.get_room_list_handler()
if server and server != self.hs.config.server.server_name:
if server and not self.hs.is_mine_server_name(server):
# Ensure the server is valid.
try:
parse_and_validate_server_name(server)
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/media/download_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DownloadResource(DirectServeJsonResource):
def __init__(self, hs: "HomeServer", media_repo: "MediaRepository"):
super().__init__()
self.media_repo = media_repo
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name

async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
Expand All @@ -59,7 +59,7 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
b"no-referrer",
)
server_name, media_id, name = parse_media_id(request)
if server_name == self.server_name:
if self._is_mine_server_name(server_name):
await self.media_repo.get_local_media(request, media_id, name)
else:
allow_remote = parse_boolean(request, "allow_remote", default=True)
Expand Down
4 changes: 2 additions & 2 deletions synapse/rest/media/thumbnail_resource.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __init__(
self.media_repo = media_repo
self.media_storage = media_storage
self.dynamic_thumbnails = hs.config.media.dynamic_thumbnails
self.server_name = hs.hostname
self._is_mine_server_name = hs.is_mine_server_name

async def _async_render_GET(self, request: SynapseRequest) -> None:
set_cors_headers(request)
Expand All @@ -71,7 +71,7 @@ async def _async_render_GET(self, request: SynapseRequest) -> None:
# TODO Parse the Accept header to get an prioritised list of thumbnail types.
m_type = "image/png"

if server_name == self.server_name:
if self._is_mine_server_name(server_name):
if self.dynamic_thumbnails:
await self._select_or_generate_local_thumbnail(
request, media_id, width, height, method, m_type
Expand Down
4 changes: 4 additions & 0 deletions synapse/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ def is_mine_id(self, string: str) -> bool:
return False
return localpart_hostname[1] == self.hostname

def is_mine_server_name(self, server_name: str) -> bool:
"""Determines whether a server name refers to this homeserver."""
return server_name == self.hostname

@cache_in_self
def get_clock(self) -> Clock:
return Clock(self._reactor)
Expand Down
2 changes: 1 addition & 1 deletion synapse/storage/databases/main/room.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,7 +996,7 @@ async def quarantine_media_by_id(
If it is `None` media will be removed from quarantine
"""
logger.info("Quarantining media: %s/%s", server_name, media_id)
is_local = server_name == self.config.server.server_name
is_local = self.hs.is_mine_server_name(server_name)

def _quarantine_media_by_id_txn(txn: LoggingTransaction) -> int:
local_mxcs = [media_id] if is_local else []
Expand Down
Loading