Skip to content

Commit

Permalink
Merge pull request #242 from natekspencer/cognito-auth
Browse files Browse the repository at this point in the history
Run blocking calls in the executor
  • Loading branch information
natekspencer authored Jan 17, 2025
2 parents 074bbc7 + 8d70a0e commit aae760b
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 39 deletions.
10 changes: 8 additions & 2 deletions pylitterbot/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ClientSession,
ClientWebSocketResponse,
)
from botocore.exceptions import ClientError

from .event import EVENT_UPDATE
from .exceptions import LitterRobotException, LitterRobotLoginException
Expand Down Expand Up @@ -107,7 +108,7 @@ async def connect(
try:
if not self.session.is_token_valid():
if self.session.has_refresh_token():
await self.session.refresh_token()
await self.session.refresh_tokens()
elif username and password:
await self.session.login(username=username, password=password)
else:
Expand All @@ -121,6 +122,11 @@ async def connect(
if load_pets:
await self.load_pets()

except ClientError as err:
_LOGGER.error(err)
raise LitterRobotLoginException(
f"Unable to login to Litter-Robot: {err.response['message']}"
) from err
except ClientResponseError as ex:
_LOGGER.error(ex)
if ex.status == 401:
Expand Down Expand Up @@ -231,7 +237,7 @@ async def refresh_robots(self) -> None:
async def get_bearer_authorization(self) -> str | None:
"""Return the authorization token."""
if not self.session.is_token_valid():
await self.session.refresh_token()
await self.session.refresh_tokens()
return await self.session.get_bearer_authorization()

async def ws_connect(self, robot: Robot) -> ClientWebSocketResponse:
Expand Down
97 changes: 63 additions & 34 deletions pylitterbot/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@

from __future__ import annotations

import asyncio
import logging
from abc import ABC, abstractmethod
from asyncio import Lock
from functools import partial
from types import TracebackType
from typing import Any, Final, TypeVar, cast

import jwt
from aiohttp import ClientSession
from botocore.exceptions import ClientError
from botocore.exceptions import ClientError, ParamValidationError
from pycognito import Cognito

from .event import EVENT_UPDATE, Event
Expand Down Expand Up @@ -63,30 +65,30 @@ async def patch(self, path: str, **kwargs: Any) -> dict | list[dict] | None:
return await self.request("PATCH", path, **kwargs)

@abstractmethod
async def async_get_access_token(self, **kwargs: Any) -> str | None:
async def async_get_id_token(self, **kwargs: Any) -> str | None:
"""Return a valid access token."""

@abstractmethod
def is_token_valid(self) -> bool:
"""Return `True` if the token is stills valid."""

async def refresh_token(self, ignore_unexpired: bool = False) -> None:
async def refresh_tokens(self, ignore_unexpired: bool = False) -> None:
"""Refresh the access token."""
if self.tokens is None:
return None
async with self._lock:
if not ignore_unexpired and self.is_token_valid():
return
await self._refresh_token()
await self._refresh_tokens()
self.emit(EVENT_UPDATE)

@abstractmethod
async def _refresh_token(self) -> None:
async def _refresh_tokens(self) -> None:
"""Actual implementation to refresh the tokens."""

async def get_bearer_authorization(self) -> str | None:
"""Get the bearer authorization."""
if (access_token := await self.async_get_access_token()) is None:
if (access_token := await self.async_get_id_token()) is None:
return None
return f"Bearer {access_token}"

Expand Down Expand Up @@ -168,17 +170,32 @@ def __init__(
self.__refresh_token = token.get("refresh_token")
self._custom_args: dict = {}

@property
def access_token(self) -> str | None:
"""Return the access token, if any."""
return self._user.access_token if self._user else self.__access_token

@property
def id_token(self) -> str | None:
"""Return the id token, if any."""
return self._user.id_token if self._user else self.__id_token

@property
def refresh_token(self) -> str | None:
"""Return the refresh token, if any."""
return self._user.refresh_token if self._user else self.__refresh_token

@property
def tokens(self) -> dict[str, str] | None:
"""Return the Cognito user tokens."""
user = self.get_user()
if None in (user.access_token, user.id_token, user.refresh_token):
"""Return the tokens."""
if None in (self.access_token, self.id_token):
return None
return {
"access_token": user.access_token,
"id_token": user.id_token,
"refresh_token": user.refresh_token,
token = {
"access_token": self.access_token,
"id_token": self.id_token,
"refresh_token": self.refresh_token,
}
return cast(dict[str, str], token)

def generate_args(self, url: str, **kwargs: Any) -> dict[str, Any]:
"""Generate args."""
Expand All @@ -197,58 +214,70 @@ def is_token_valid(self) -> bool:
return False
try:
jwt.decode(
self.tokens["id_token"],
self.id_token,
options={"verify_signature": False, "verify_exp": True},
leeway=-30,
)
except jwt.ExpiredSignatureError:
return False
return True

async def async_get_access_token(self, **kwargs: Any) -> str | None:
"""Return a valid access token."""
async def async_get_id_token(self, **kwargs: Any) -> str | None:
"""Return a valid id token."""
if self.tokens is None or not self.is_token_valid():
return None
return self.tokens["id_token"]
return self.id_token

async def login(self, username: str, password: str) -> None:
"""Login to the Litter-Robot api and generate a new token."""
self._username = username
self.get_user().authenticate(password=password)
user = await self.get_user()
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, partial(user.authenticate, password=password))
self.emit(EVENT_UPDATE)

async def _refresh_token(self) -> None:
async def _refresh_tokens(self) -> None:
"""Refresh the access token."""
# This should be handled by pycognito automatically, but in case we do get here, we'll manually refresh
_LOGGER.debug("Manually refreshing token")
self.get_user().renew_access_token()
user = await self.get_user()
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, user.renew_access_token)

async def request(
self, method: str, url: str, **kwargs: Any
) -> dict | list[dict] | None:
"""Make a request."""
kwargs = self.generate_args(url, **kwargs)
if not kwargs.pop("skip_auth", False) and not self.is_token_valid():
await self.refresh_token()
await self.refresh_tokens()
return await super().request(method, url, **kwargs)

def get_user(self) -> Cognito:
async def get_user(self) -> Cognito:
"""Return the Cognito user."""
if self._user is None:
self._user = Cognito(
decode(self.USER_POOL_ID),
decode(self.CLIENT_ID),
username=self._username,
access_token=self.__access_token,
id_token=self.__id_token,
refresh_token=self.__refresh_token,
loop = asyncio.get_running_loop()
self._user = await loop.run_in_executor(
None,
partial(
Cognito,
decode(self.USER_POOL_ID),
decode(self.CLIENT_ID),
username=self._username,
access_token=self.__access_token,
id_token=self.__id_token,
refresh_token=self.__refresh_token,
),
)
assert self._user
if self.__access_token and self.__id_token:
try:
self._user.check_token()
await loop.run_in_executor(None, self._user.check_token)
self._user.verify_tokens()
except ClientError as err:
_LOGGER.error(err)
raise err
except ParamValidationError:
# tokens are invalid
pass
if self._username and not self._user.username:
self._user.username = self._username
return self._user
Expand All @@ -258,11 +287,11 @@ def get_user_id(self) -> str | None:
if self.tokens is None:
return None
user_id = jwt.decode(
self.tokens["id_token"],
self.id_token,
options={"verify_signature": False, "verify_exp": False},
)["mid"]
return cast(str, user_id)

def has_refresh_token(self) -> bool:
"""Return `True` if the session has a refresh token."""
return self.tokens is not None and self.tokens["refresh_token"] is not None
return self.tokens is not None and self.refresh_token is not None
6 changes: 4 additions & 2 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,13 @@ async def test_token_refresh(mock_aioresponse: aioresponses) -> None:

async with LitterRobotSession() as session:
assert not session.is_token_valid()
await session.refresh_token()
await session.refresh_tokens()
assert not session.is_token_valid()

async with LitterRobotSession(token=EXPIRED_ACCESS_TOKEN) as session:
assert session.is_token_valid() # pycognito auto refreshes
assert not session.is_token_valid()
await session.patch("localhost")
assert session.is_token_valid()


async def test_custom_headers() -> None:
Expand Down
21 changes: 20 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
"""Test utils module."""

from pylitterbot.utils import REDACTED, decode, encode, redact, round_time, to_timestamp
from pylitterbot.utils import (
REDACTED,
decode,
encode,
first_value,
redact,
round_time,
to_timestamp,
)


def test_round_time_default() -> None:
Expand Down Expand Up @@ -31,3 +39,14 @@ def test_redact() -> None:
data = {"key": "value"}
assert redact(data) == data
assert redact([data, data]) == [data, data]


def test_first_value() -> None:
"""Test looking up values from a dictionary."""
values = {"key1": 1, "key2": 2, "key4": 4}
assert first_value(values, ("key1", "key2")) == 1
assert first_value(values, ("key2", "key3")) == 2
assert first_value(values, ("key3", "key4")) == 4
assert first_value(values, ("key3", "key5")) is None
assert first_value(values, ("key3", "key5"), 0) == 0
assert first_value(None, ("key3", "key5"), 10) == 10

0 comments on commit aae760b

Please sign in to comment.