diff --git a/src/synology_dsm/synology_dsm.py b/src/synology_dsm/synology_dsm.py index 039931b..0c9b013 100644 --- a/src/synology_dsm/synology_dsm.py +++ b/src/synology_dsm/synology_dsm.py @@ -11,7 +11,15 @@ from typing import Any, Coroutine, TypedDict from urllib.parse import quote, urlencode -from aiohttp import ClientError, ClientSession, ClientTimeout, MultipartWriter, hdrs +from aiohttp import ( + ClientError, + ClientResponse, + ClientSession, + ClientTimeout, + MultipartWriter, + StreamReader, + hdrs, +) from yarl import URL from .api import SynoBaseApi @@ -239,13 +247,13 @@ def device_token(self) -> str | None: async def get( self, api: str, method: str, params: dict | None = None, **kwargs: Any - ) -> bytes | dict | str: + ) -> bytes | dict | str | StreamReader: """Handles API GET request.""" return await self._request("GET", api, method, params, **kwargs) async def post( self, api: str, method: str, params: dict | None = None, **kwargs: Any - ) -> bytes | dict | str: + ) -> bytes | dict | str | StreamReader: """Handles API POST request.""" return await self._request("POST", api, method, params, **kwargs) @@ -310,8 +318,9 @@ async def _request( method: str, params: dict | None = None, retry_once: bool = True, + raw_response_content: bool = False, **kwargs: Any, - ) -> bytes | dict | str: + ) -> bytes | dict | str | StreamReader: """Handles API request.""" url, params, kwargs = await self._prepare_request(api, method, params, **kwargs) @@ -319,28 +328,61 @@ async def _request( self._debuglog("---------------------------------------------------------") self._debuglog("API: " + api) self._debuglog("Request Method: " + request_method) - response = await self._execute_request(request_method, url, params, **kwargs) + + try: + response = await self._execute_request( + request_method, url, params, **kwargs + ) + if response.status != 200: + # We got a 400, 401 or 404 ... + raise ClientError(response) + except (ClientError, asyncio.TimeoutError) as exp: + raise SynologyDSMRequestException(exp) from exp + + content_type = response.headers.get("Content-Type", "").split(";")[0] + result: bytes | dict | str | StreamReader + if raw_response_content: + result = response.content + elif content_type in [ + "application/json", + "text/json", + "text/plain", # Can happen with some API + ]: + try: + result = dict(await response.json(content_type=content_type)) + except JSONDecodeError as exp: + raise SynologyDSMRequestException(exp) from exp + elif content_type == "application/octet-stream" or content_type.startswith( + "image" + ): + result = await response.read() + else: + result = await response.text() + self._debuglog("Successful returned data") - self._debuglog("RESPONSE: " + str(response)) + if raw_response_content: + self._debuglog("RESPONSE: omitted since raw response requested") + else: + self._debuglog("RESPONSE: " + str(result)) # Handle data errors - if isinstance(response, dict) and response.get("error") and api != API_AUTH: - self._debuglog("Session error: " + str(response["error"]["code"])) - if response["error"]["code"] == 119 and retry_once: + if isinstance(result, dict) and result.get("error") and api != API_AUTH: + self._debuglog("Session error: " + str(result["error"]["code"])) + if result["error"]["code"] == 119 and retry_once: # Session ID not valid # see https://github.com/aerialls/synology-srm/pull/3 self._session_id = None self._syno_token = None return await self._request(request_method, api, method, params, False) raise SynologyDSMAPIErrorException( - api, response["error"]["code"], response["error"].get("errors") + api, result["error"]["code"], result["error"].get("errors") ) - return response + return result async def _execute_request( self, method: str, url: URL, params: dict, **kwargs: Any - ) -> bytes | dict | str: + ) -> ClientResponse: """Function to execute and handle a request.""" # special handling for spaces in parameters # because yarl.URL does encode a space as + instead of %20 @@ -348,86 +390,62 @@ async def _execute_request( query = urlencode(params, safe="?/:@-._~!$'()*,", quote_via=quote) url_encoded = url.join(URL(f"?{query}", encoded=True)) - try: - if method == "GET": - response = await self._session.get( - url_encoded, timeout=self._aiohttp_timeout, **kwargs + if method == "GET": + response = await self._session.get( + url_encoded, timeout=self._aiohttp_timeout, **kwargs + ) + elif method == "GET" and params.get("api") == SynoFileStation.DOWNLOAD_API_KEY: + pass + elif method == "POST" and params.get("api") == SynoFileStation.UPLOAD_API_KEY: + content = kwargs.pop("content") + path = kwargs.pop("path") + filename = kwargs.pop("filename") + + boundary = md5( + str(url_encoded).encode("utf-8"), usedforsecurity=False + ).hexdigest() + with MultipartWriter("form-data", boundary=boundary) as mp: + part = mp.append(path) + part.headers.pop(hdrs.CONTENT_TYPE) + part.set_content_disposition("form-data", name="path") + + part = mp.append(content) + part.headers.pop(hdrs.CONTENT_TYPE) + part.set_content_disposition( + "form-data", name="file", filename=filename ) - elif ( - method == "POST" and params.get("api") == SynoFileStation.UPLOAD_API_KEY - ): - content = kwargs.pop("content") - path = kwargs.pop("path") - filename = kwargs.pop("filename") - - boundary = md5( - str(url_encoded).encode("utf-8"), usedforsecurity=False - ).hexdigest() - with MultipartWriter("form-data", boundary=boundary) as mp: - part = mp.append(path) - part.headers.pop(hdrs.CONTENT_TYPE) - part.set_content_disposition("form-data", name="path") - - part = mp.append(content) - part.headers.pop(hdrs.CONTENT_TYPE) - part.set_content_disposition( - "form-data", name="file", filename=filename - ) - part.headers.add(hdrs.CONTENT_TYPE, "application/octet-stream") - - response = await self._session.post( - url_encoded, - timeout=ClientTimeout(connect=10.0, total=43200.0), - data=mp, - ) - elif method == "POST": - data = {} - if params is not None: - data.update(params) - data.update(kwargs.pop("data", {})) - data["mimeType"] = "application/json" - kwargs["data"] = data - self._debuglog("POST data: " + str(data)) + part.headers.add(hdrs.CONTENT_TYPE, "application/octet-stream") response = await self._session.post( - url_encoded, timeout=self._aiohttp_timeout, **kwargs + url_encoded, + timeout=ClientTimeout(connect=10.0, total=43200.0), + data=mp, ) + elif method == "POST": + data = {} + if params is not None: + data.update(params) + data.update(kwargs.pop("data", {})) + data["mimeType"] = "application/json" + kwargs["data"] = data + self._debuglog("POST data: " + str(data)) + + response = await self._session.post( + url_encoded, timeout=self._aiohttp_timeout, **kwargs + ) - # mask sesitive parameters - if _LOGGER.isEnabledFor(logging.DEBUG) or self._debugmode: - response_url = response.url # pylint: disable=E0606 - for param in SENSITIV_PARAMS: - if params is not None and params.get(param): - response_url = response_url.update_query({param: "*********"}) - self._debuglog("Request url: " + str(response_url)) - self._debuglog("Request headers: " + str(response.request_info.headers)) - self._debuglog("Response status_code: " + str(response.status)) - self._debuglog("Response headers: " + str(dict(response.headers))) - - if response.status == 200: - # We got a DSM response - content_type = response.headers.get("Content-Type", "").split(";")[0] - - if content_type in [ - "application/json", - "text/json", - "text/plain", # Can happen with some API - ]: - return dict(await response.json(content_type=content_type)) - - if ( - content_type == "application/octet-stream" - or content_type.startswith("image") - ): - return await response.read() - - return await response.text() - - # We got a 400, 401 or 404 ... - raise ClientError(response) - - except (ClientError, asyncio.TimeoutError, JSONDecodeError) as exp: - raise SynologyDSMRequestException(exp) from exp + # mask sesitive parameters + if _LOGGER.isEnabledFor(logging.DEBUG) or self._debugmode: + response_url = response.url # pylint: disable=E0606 + for param in SENSITIV_PARAMS: + if params is not None and params.get(param): + response_url = response_url.update_query({param: "*********"}) + self._debuglog("Request url: " + str(response_url)) + self._debuglog("Request headers: " + str(response.request_info.headers)) + self._debuglog("Response status_code: " + str(response.status)) + self._debuglog("Response headers: " + str(dict(response.headers))) + + return response async def update( self, with_information: bool = False, with_network: bool = False