Skip to content

Commit

Permalink
Adapt code to mypy standards
Browse files Browse the repository at this point in the history
  • Loading branch information
shaiarmis committed Mar 11, 2025
1 parent 0532367 commit 677770e
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 21 deletions.
8 changes: 7 additions & 1 deletion armis_sdk/clients/network_equipment_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,12 @@ async def main():
asyncio.run(main())
```
"""
if site.id is None:
raise ArmisError("The property 'id' must be set.")

if not network_equipment_device_ids:
return

await self._insert(site.id, set(network_equipment_device_ids))

async def update(self, site: Site):
Expand Down Expand Up @@ -72,6 +75,9 @@ async def main():
```
"""

if site.id is None:
raise ArmisError("The property 'id' must be set.")

if site.network_equipment_device_ids is None:
raise ArmisError("The property 'network_equipment_device_ids' must be set.")

Expand Down Expand Up @@ -117,5 +123,5 @@ async def _insert(self, site_id: int, network_equipment_device_ids: Set[int]):
async def _list(self, site_id) -> List[int]:
async with self._armis_client.client() as client:
response = await client.get(f"/api/v1/sites/{site_id}/network-equipment/")
data = self._get_data(response)
data = self._get_dict(response)
return data["networkEquipmentDeviceIds"]
5 changes: 4 additions & 1 deletion armis_sdk/clients/site_integrations_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ async def main():
```
"""

if site.id is None:
raise ArmisError("The property 'id' must be set.")

if site.integration_ids is None:
raise ArmisError("The property 'integration_ids' must be set.")

Expand Down Expand Up @@ -101,5 +104,5 @@ async def _insert(self, site_id: int, integration_ids: Set[int]):
async def _list(self, site_id: int) -> List[int]:
async with self._armis_client.client() as client:
response = await client.get(f"/api/v1/sites/{site_id}/integrations-ids/")
data = self._get_data(response)
data = self._get_dict(response)
return data["integrationIds"]
4 changes: 2 additions & 2 deletions armis_sdk/clients/sites_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ async def main():
)
async with self._armis_client.client() as client:
response = await client.post("/api/v1/sites/", json=payload)
data = self._get_data(response)
data = self._get_dict(response)
created_site = site.model_copy(update={"id": int(data["id"])}, deep=True)

if site.network_equipment_device_ids:
Expand Down Expand Up @@ -142,7 +142,7 @@ async def main():
"""
async with self._armis_client.client() as client:
response = await client.get(f"/api/v1/sites/{site_id}/")
data = self._get_data(response)
data: dict = self._get_dict(response)
return Site.model_validate(data)

async def hierarchy(self) -> List[Site]:
Expand Down
18 changes: 13 additions & 5 deletions armis_sdk/core/armis_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import httpx

from armis_sdk.core import response_utils
from armis_sdk.core.armis_error import ArmisError

AUTHORIZATION = "Authorization"

Expand All @@ -27,17 +28,24 @@ def __init__(self, base_url: str, secret_key: str):
self._base_url = base_url
self._secret_key = secret_key
self._access_token: Optional[str] = None
self._expires_at: Optional[datetime] = None
self._expires_at: Optional[datetime.datetime] = None

def auth_flow(
self, request: httpx.Request
) -> typing.Generator[httpx.Request, httpx.Response, None]:
if self._access_token is None or self._expires_at < datetime.datetime.now(
datetime.timezone.utc
if (
self._access_token is None
or self._expires_at is None
or self._expires_at < datetime.datetime.now(datetime.timezone.utc)
):
access_token_response = yield self._build_access_token_request()
self._update_access_token(access_token_response)

if self._access_token is None:
raise ArmisError(
"Something went wrong, there is no access token available."
)

request.headers[AUTHORIZATION] = self._access_token
response = yield request

Expand All @@ -57,7 +65,7 @@ def _build_access_token_request(self):

def _update_access_token(self, response: httpx.Response):
response_utils.raise_for_status(response)
parsed = response_utils.parse_response(response)
data = parsed.get("data")
parsed = response_utils.parse_response(response, dict)
data: dict = parsed.get("data") or {}
self._access_token = data["access_token"]
self._expires_at = datetime.datetime.fromisoformat(data["expiration_utc"])
3 changes: 2 additions & 1 deletion armis_sdk/core/armis_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

from typing import List
from typing import Optional

from httpx import HTTPStatusError

Expand All @@ -21,7 +22,7 @@ class ResponseError(ArmisError):
For example, if the server returns 400 for invalid input, an instance of this class will be raised.
"""

def __init__(self, *args, response_errors: List[HTTPStatusError] = None):
def __init__(self, *args, response_errors: Optional[List[HTTPStatusError]] = None):
super().__init__(*args)
self.response_errors = response_errors

Expand Down
24 changes: 19 additions & 5 deletions armis_sdk/core/base_entity_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from typing import AsyncIterator
from typing import Optional
from typing import Type
from typing import Union
from typing import TypeVar

import httpx

from armis_sdk.core import response_utils
from armis_sdk.core.armis_client import ArmisClient
from armis_sdk.core.armis_error import ResponseError
from armis_sdk.core.base_entity import BaseEntityT

ARMIS_CLIENT_ID = "ARMIS_CLIENT_ID"
Expand All @@ -16,17 +17,30 @@
ARMIS_TENANT = "ARMIS_TENANT"
DEFAULT_PAGE_LENGTH = 100

DataTypeT = TypeVar("DataTypeT", dict, list)


class BaseEntityClient: # pylint: disable=too-few-public-methods

def __init__(self, armis_client: Optional[ArmisClient] = None) -> None:
self._armis_client = armis_client or ArmisClient()

@classmethod
def _get_data(cls, response: httpx.Response) -> Optional[Union[dict, list]]:
def _get_data(
cls,
response: httpx.Response,
data_type: Type[DataTypeT],
) -> DataTypeT:
response_utils.raise_for_status(response)
parsed = response_utils.parse_response(response)
return parsed.get("data")
parsed = response_utils.parse_response(response, dict)
data = parsed.get("data")
if not isinstance(data, data_type):
raise ResponseError("Response data represents neither a dict nor a list.")
return data

@classmethod
def _get_dict(cls, response: httpx.Response):
return cls._get_data(response, dict)

async def _paginate(
self, url: str, key: str, model: Type[BaseEntityT]
Expand All @@ -36,7 +50,7 @@ async def _paginate(
from_ = 0
while from_ is not None:
params = {"from": from_, "length": page_size}
data = self._get_data(await client.get(url, params=params))
data = self._get_dict(await client.get(url, params=params))
items = data[key]
for item in items:
yield model.model_validate(item)
Expand Down
18 changes: 13 additions & 5 deletions armis_sdk/core/response_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from json import JSONDecodeError
from typing import Optional
from typing import Union
from typing import Type
from typing import TypeVar

import httpx
from httpx import HTTPStatusError
Expand All @@ -10,10 +10,18 @@
from armis_sdk.core.armis_error import NotFoundError
from armis_sdk.core.armis_error import ResponseError

DataTypeT = TypeVar("DataTypeT", dict, list)

def parse_response(response: httpx.Response) -> Optional[Union[dict, list]]:

def parse_response(
response: httpx.Response,
data_type: Type[DataTypeT],
) -> DataTypeT:
try:
return response.json()
response_data = response.json()
if isinstance(response_data, data_type):
return response_data
raise ResponseError("Response body represents neither a dict nor a list.")
except JSONDecodeError as error:
message = f"Response body is not a valid JSON: {response.text}"
raise ResponseError(message) from error
Expand All @@ -23,7 +31,7 @@ def raise_for_status(response: httpx.Response):
try:
response.raise_for_status()
except HTTPStatusError as error:
parsed = parse_response(error.response)
parsed = parse_response(error.response, dict)
message = parsed.get("message", "Something went wrong.")

if error.response.status_code == httpx.codes.NOT_FOUND:
Expand Down
2 changes: 1 addition & 1 deletion armis_sdk/entities/site.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class Site(BaseEntity):
] = None
"""The ids of the integration associated with the site."""

children: Optional[List["Site"]] = Field(default_factory=list)
children: List["Site"] = Field(default_factory=list)
"""The sub-sites that are directly under this site
(their `parent_id` will match this site's `id`)."""

Expand Down

0 comments on commit 677770e

Please sign in to comment.