Skip to content

Commit

Permalink
allow to use raw response content (StreamReader)
Browse files Browse the repository at this point in the history
  • Loading branch information
mib1185 committed Jan 5, 2025
1 parent 8f62661 commit b5b6feb
Showing 1 changed file with 105 additions and 87 deletions.
192 changes: 105 additions & 87 deletions src/synology_dsm/synology_dsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@
from typing import Any, Coroutine, TypedDict
from urllib.parse import quote, urlencode

from aiohttp import ClientError, ClientSession, ClientTimeout, MultipartWriter, hdrs
from aiohttp import (
ClientError,
ClientResponse,
ClientSession,
ClientTimeout,
MultipartWriter,
StreamReader,
hdrs,
)
from yarl import URL

from .api import SynoBaseApi
Expand Down Expand Up @@ -239,13 +247,13 @@ def device_token(self) -> str | None:

async def get(
self, api: str, method: str, params: dict | None = None, **kwargs: Any
) -> bytes | dict | str:
) -> bytes | dict | str | StreamReader:
"""Handles API GET request."""
return await self._request("GET", api, method, params, **kwargs)

async def post(
self, api: str, method: str, params: dict | None = None, **kwargs: Any
) -> bytes | dict | str:
) -> bytes | dict | str | StreamReader:
"""Handles API POST request."""
return await self._request("POST", api, method, params, **kwargs)

Expand Down Expand Up @@ -310,124 +318,134 @@ async def _request(
method: str,
params: dict | None = None,
retry_once: bool = True,
raw_response_content: bool = False,
**kwargs: Any,
) -> bytes | dict | str:
) -> bytes | dict | str | StreamReader:
"""Handles API request."""
url, params, kwargs = await self._prepare_request(api, method, params, **kwargs)

# Request data
self._debuglog("---------------------------------------------------------")
self._debuglog("API: " + api)
self._debuglog("Request Method: " + request_method)
response = await self._execute_request(request_method, url, params, **kwargs)

try:
response = await self._execute_request(
request_method, url, params, **kwargs
)
if response.status != 200:
# We got a 400, 401 or 404 ...
raise ClientError(response)
except (ClientError, asyncio.TimeoutError) as exp:
raise SynologyDSMRequestException(exp) from exp

content_type = response.headers.get("Content-Type", "").split(";")[0]
result: bytes | dict | str | StreamReader
if raw_response_content:
result = response.content
elif content_type in [
"application/json",
"text/json",
"text/plain", # Can happen with some API
]:
try:
result = dict(await response.json(content_type=content_type))
except JSONDecodeError as exp:
raise SynologyDSMRequestException(exp) from exp
elif content_type == "application/octet-stream" or content_type.startswith(
"image"
):
result = await response.read()
else:
result = await response.text()

self._debuglog("Successful returned data")
self._debuglog("RESPONSE: " + str(response))
if raw_response_content:
self._debuglog("RESPONSE: omitted since raw response requested")
else:
self._debuglog("RESPONSE: " + str(result))

# Handle data errors
if isinstance(response, dict) and response.get("error") and api != API_AUTH:
self._debuglog("Session error: " + str(response["error"]["code"]))
if response["error"]["code"] == 119 and retry_once:
if isinstance(result, dict) and result.get("error") and api != API_AUTH:
self._debuglog("Session error: " + str(result["error"]["code"]))
if result["error"]["code"] == 119 and retry_once:
# Session ID not valid
# see https://github.com/aerialls/synology-srm/pull/3
self._session_id = None
self._syno_token = None
return await self._request(request_method, api, method, params, False)
raise SynologyDSMAPIErrorException(
api, response["error"]["code"], response["error"].get("errors")
api, result["error"]["code"], result["error"].get("errors")
)

return response
return result

async def _execute_request(
self, method: str, url: URL, params: dict, **kwargs: Any
) -> bytes | dict | str:
) -> ClientResponse:
"""Function to execute and handle a request."""
# special handling for spaces in parameters
# because yarl.URL does encode a space as + instead of %20
# safe extracted from yarl.URL._QUERY_PART_QUOTER
query = urlencode(params, safe="?/:@-._~!$'()*,", quote_via=quote)
url_encoded = url.join(URL(f"?{query}", encoded=True))

try:
if method == "GET":
response = await self._session.get(
url_encoded, timeout=self._aiohttp_timeout, **kwargs
if method == "GET":
response = await self._session.get(
url_encoded, timeout=self._aiohttp_timeout, **kwargs
)
elif method == "GET" and params.get("api") == SynoFileStation.DOWNLOAD_API_KEY:
pass
elif method == "POST" and params.get("api") == SynoFileStation.UPLOAD_API_KEY:
content = kwargs.pop("content")
path = kwargs.pop("path")
filename = kwargs.pop("filename")

boundary = md5(
str(url_encoded).encode("utf-8"), usedforsecurity=False
).hexdigest()
with MultipartWriter("form-data", boundary=boundary) as mp:
part = mp.append(path)
part.headers.pop(hdrs.CONTENT_TYPE)
part.set_content_disposition("form-data", name="path")

part = mp.append(content)
part.headers.pop(hdrs.CONTENT_TYPE)
part.set_content_disposition(
"form-data", name="file", filename=filename
)
elif (
method == "POST" and params.get("api") == SynoFileStation.UPLOAD_API_KEY
):
content = kwargs.pop("content")
path = kwargs.pop("path")
filename = kwargs.pop("filename")

boundary = md5(
str(url_encoded).encode("utf-8"), usedforsecurity=False
).hexdigest()
with MultipartWriter("form-data", boundary=boundary) as mp:
part = mp.append(path)
part.headers.pop(hdrs.CONTENT_TYPE)
part.set_content_disposition("form-data", name="path")

part = mp.append(content)
part.headers.pop(hdrs.CONTENT_TYPE)
part.set_content_disposition(
"form-data", name="file", filename=filename
)
part.headers.add(hdrs.CONTENT_TYPE, "application/octet-stream")

response = await self._session.post(
url_encoded,
timeout=ClientTimeout(connect=10.0, total=43200.0),
data=mp,
)
elif method == "POST":
data = {}
if params is not None:
data.update(params)
data.update(kwargs.pop("data", {}))
data["mimeType"] = "application/json"
kwargs["data"] = data
self._debuglog("POST data: " + str(data))
part.headers.add(hdrs.CONTENT_TYPE, "application/octet-stream")

response = await self._session.post(
url_encoded, timeout=self._aiohttp_timeout, **kwargs
url_encoded,
timeout=ClientTimeout(connect=10.0, total=43200.0),
data=mp,
)
elif method == "POST":
data = {}
if params is not None:
data.update(params)
data.update(kwargs.pop("data", {}))
data["mimeType"] = "application/json"
kwargs["data"] = data
self._debuglog("POST data: " + str(data))

response = await self._session.post(
url_encoded, timeout=self._aiohttp_timeout, **kwargs
)

# mask sesitive parameters
if _LOGGER.isEnabledFor(logging.DEBUG) or self._debugmode:
response_url = response.url # pylint: disable=E0606
for param in SENSITIV_PARAMS:
if params is not None and params.get(param):
response_url = response_url.update_query({param: "*********"})
self._debuglog("Request url: " + str(response_url))
self._debuglog("Request headers: " + str(response.request_info.headers))
self._debuglog("Response status_code: " + str(response.status))
self._debuglog("Response headers: " + str(dict(response.headers)))

if response.status == 200:
# We got a DSM response
content_type = response.headers.get("Content-Type", "").split(";")[0]

if content_type in [
"application/json",
"text/json",
"text/plain", # Can happen with some API
]:
return dict(await response.json(content_type=content_type))

if (
content_type == "application/octet-stream"
or content_type.startswith("image")
):
return await response.read()

return await response.text()

# We got a 400, 401 or 404 ...
raise ClientError(response)

except (ClientError, asyncio.TimeoutError, JSONDecodeError) as exp:
raise SynologyDSMRequestException(exp) from exp
# mask sesitive parameters
if _LOGGER.isEnabledFor(logging.DEBUG) or self._debugmode:
response_url = response.url # pylint: disable=E0606
for param in SENSITIV_PARAMS:
if params is not None and params.get(param):
response_url = response_url.update_query({param: "*********"})
self._debuglog("Request url: " + str(response_url))
self._debuglog("Request headers: " + str(response.request_info.headers))
self._debuglog("Response status_code: " + str(response.status))
self._debuglog("Response headers: " + str(dict(response.headers)))

return response

async def update(
self, with_information: bool = False, with_network: bool = False
Expand Down

0 comments on commit b5b6feb

Please sign in to comment.