Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MSC3916 by adding _matrix/client/v1/media/download endpoint #17365

Merged
merged 13 commits into from
Jul 2, 2024
Merged
30 changes: 28 additions & 2 deletions synapse/federation/federation_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1880,8 +1880,34 @@ async def federation_download_media(
max_timeout_ms: int,
download_ratelimiter: Ratelimiter,
ip_address: str,
) -> Tuple[int, Dict[bytes, List[bytes]], bytes]:
return await self.transport_layer.federation_download_media(
) -> Union[
Tuple[int, Dict[bytes, List[bytes]], bytes],
Tuple[int, Dict[bytes, List[bytes]]],
]:
try:
return await self.transport_layer.federation_download_media(
destination,
media_id,
output_stream=output_stream,
max_size=max_size,
max_timeout_ms=max_timeout_ms,
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
except HttpResponseException as e:
# If an error is received that is due to an unrecognised endpoint,
# fall back to the _matrix/media/v3/download endpoint. Otherwise, consider it a
# legitimate error and raise.
if not is_unknown_endpoint(e):
raise

logger.debug(
"Couldn't download media %s/%s over _matrix/federation/v1/media/download, falling back to _matrix/media/v3/download path",
destination,
media_id,
)

return await self.transport_layer.download_media_v3(
destination,
media_id,
output_stream=output_stream,
Expand Down
8 changes: 7 additions & 1 deletion synapse/media/media_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ async def _federation_download_remote_file(

async with self.media_storage.store_into_file(file_info) as (f, fname):
try:
length, headers, json = await self.client.federation_download_media(
res = await self.client.federation_download_media(
server_name,
media_id,
output_stream=f,
Expand All @@ -839,6 +839,12 @@ async def _federation_download_remote_file(
download_ratelimiter=download_ratelimiter,
ip_address=ip_address,
)
# if we had to fall back to the _matrix/media endpoint it will only return
# the headers and length, check the length of the tuple before unpacking
if len(res) == 3:
length, headers, json = res
else:
length, headers = res
except RequestSendFailed as e:
logger.warning(
"Request failed fetching remote media %s/%s: %r",
Expand Down
94 changes: 87 additions & 7 deletions tests/rest/client/test_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,11 @@
import json
import os
import re
from typing import Any, BinaryIO, ClassVar, Dict, List, Optional, Sequence, Tuple, Type
from typing import Any, BinaryIO, Dict, List, Optional, Sequence, Tuple, Type
from unittest.mock import MagicMock, Mock, patch
from urllib import parse
from urllib.parse import quote, urlencode

from parameterized import parameterized_class

from twisted.internet import defer
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
Expand Down Expand Up @@ -60,7 +58,6 @@
from tests import unittest
from tests.media.test_media_storage import (
SVG,
TestImage,
empty_file,
small_lossless_webp,
small_png,
Expand Down Expand Up @@ -1898,9 +1895,10 @@ def test_file_download(self) -> None:
input_values = [(x,) for x in test_images]


@parameterized_class(("test_image",), input_values)
# @parameterized_class(("test_image",), input_values)
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
class DownloadTestCase(unittest.HomeserverTestCase):
test_image: ClassVar[TestImage]
# test_image: ClassVar[TestImage]
H-Shay marked this conversation as resolved.
Show resolved Hide resolved
test_image = SVG
servlets = [
media.register_servlets,
login.register_servlets,
Expand All @@ -1910,7 +1908,7 @@ class DownloadTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.fetches: List[
Tuple[
"Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]], bytes]]]",
"Deferred[Any]",
str,
str,
Optional[QueryParams],
Expand Down Expand Up @@ -1951,9 +1949,42 @@ def write_err(f: Failure) -> Failure:
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)

def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
download_ratelimiter: Ratelimiter,
ip_address: Any,
max_size: int,
args: Optional[QueryParams] = None,
retry_on_dns_fail: bool = True,
ignore_backoff: bool = False,
follow_redirects: bool = False,
) -> "Deferred[Tuple[int, Dict[bytes, List[bytes]]]]":
"""A mock for MatrixFederationHttpClient.get_file.

Records the request in self.fetches as (deferred, destination, path, args)
so that the test can later fire the deferred with either a result or an
error. No network request is made here.

On success the deferred is expected to be fired with (data, response): the
data is written to output_stream and the response tuple is returned.
On an HttpResponseException the exception's `response` attribute is written
to output_stream and the failure is passed on to the caller.
"""

def write_to(
r: Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]
) -> Tuple[int, Dict[bytes, List[bytes]]]:
# Split the fired value into the payload and the response tuple, write
# the payload out, and hand back only the response tuple.
data, response = r
output_stream.write(data)
return response

def write_err(f: Failure) -> Failure:
# Only HttpResponseException is handled here; trap() re-raises any
# other kind of failure.
f.trap(HttpResponseException)
output_stream.write(f.value.response)
# Return the failure so the error still propagates to the caller.
return f

d: Deferred[Tuple[bytes, Tuple[int, Dict[bytes, List[bytes]]]]] = Deferred()
# Keep hold of the deferred so the test can resolve it manually.
self.fetches.append((d, destination, path, args))
# Note that this callback changes the value held by d.
d_after_callback = d.addCallbacks(write_to, write_err)
return make_deferred_yieldable(d_after_callback)

# Mock out the homeserver's MatrixFederationHttpClient
client = Mock()
client.federation_get_file = federation_get_file
client.get_file = get_file

self.storage_path = self.mktemp()
self.media_store_path = self.mktemp()
Expand Down Expand Up @@ -2128,3 +2159,52 @@ def test_cross_origin_resource_policy_header(self) -> None:
headers.getRawHeaders(b"Cross-Origin-Resource-Policy"),
[b"cross-origin"],
)

def test_unknown_federation_endpoint(self) -> None:
"""
Test that if the download request to the remote federation endpoint returns
a 404 with an M_UNRECOGNIZED error, we fall back to the
_matrix/media/v3/download endpoint.
"""
# Start the client download without waiting for the result: the mocked
# federation fetches recorded in self.fetches are resolved manually below.
channel = self.make_request(
"GET",
f"/_matrix/client/v1/media/download/{self.remote}/{self.media_id}",
shorthand=False,
await_result=False,
access_token=self.tok,
)
self.pump()

# We've made one fetch, to example.com, using the federation media download
# URL. (The query args of the fetch are not asserted in this test.)
self.assertEqual(len(self.fetches), 1)
self.assertEqual(self.fetches[0][1], "example.com")
self.assertEqual(
self.fetches[0][2], f"/_matrix/federation/v1/media/download/{self.media_id}"
)

# Fail the first fetch with a response which says the endpoint is unknown.
unknown_endpoint = b'{"errcode":"M_UNRECOGNIZED","error":"Unknown request"}'
self.fetches[0][0].errback(
HttpResponseException(404, "NOT FOUND", unknown_endpoint)
)

self.pump()

# There should now be another request to the _matrix/media/v3/download URL.
self.assertEqual(len(self.fetches), 2)
self.assertEqual(self.fetches[1][1], "example.com")
self.assertEqual(
self.fetches[1][2],
f"/_matrix/media/v3/download/example.com/{self.media_id}",
)

# Complete the fallback fetch with the image data plus a (length, headers)
# response tuple.
headers = {
b"Content-Length": [b"%d" % (len(self.test_image.data))],
}

self.fetches[1][0].callback(
(self.test_image.data, (len(self.test_image.data), headers))
)

# The original client request should now have completed successfully.
self.pump()
self.assertEqual(channel.code, 200)
Loading