Skip to content

Commit

Permalink
fix: Added prevention around multiple access token requests
Browse files Browse the repository at this point in the history
  • Loading branch information
BottlecapDave committed Jan 1, 2024
1 parent 6f80f76 commit 0ecd453
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions custom_components/octopus_energy/api_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import aiohttp
from asyncio import TimeoutError
from datetime import (datetime, timedelta, time)
from threading import RLock

from homeassistant.util.dt import (as_utc, now, as_local, parse_datetime)

Expand Down Expand Up @@ -363,6 +364,7 @@ def __init__(self, message: str, errors: list[str]):
self.errors = errors

class OctopusEnergyApiClient:
_refresh_token_lock = RLock()

def __init__(self, api_key, electricity_price_cap = None, gas_price_cap = None, timeout_in_seconds = 15):
if (api_key is None):
Expand All @@ -379,33 +381,38 @@ def __init__(self, api_key, electricity_price_cap = None, gas_price_cap = None,
self._electricity_price_cap = electricity_price_cap
self._gas_price_cap = gas_price_cap

self._timeout = aiohttp.ClientTimeout(total=timeout_in_seconds)
self._timeout = aiohttp.ClientTimeout(total=None, sock_connect=timeout_in_seconds, sock_read=timeout_in_seconds)
self._default_headers = { "user-agent": f'{user_agent_value}/{INTEGRATION_VERSION}' }

async def async_refresh_token(self):
"""Get the user's refresh token"""
if (self._graphql_expiration is not None and (self._graphql_expiration - timedelta(minutes=5)) > now()):
return

try:
async with aiohttp.ClientSession(timeout=self._timeout, headers=self._default_headers) as client:
url = f'{self._base_url}/v1/graphql/'
payload = { "query": api_token_query.format(api_key=self._api_key) }
async with client.post(url, json=payload) as token_response:
token_response_body = await self.__async_read_response__(token_response, url)
if (token_response_body is not None and
"data" in token_response_body and
"obtainKrakenToken" in token_response_body["data"] and
token_response_body["data"]["obtainKrakenToken"] is not None and
"token" in token_response_body["data"]["obtainKrakenToken"]):

self._graphql_token = token_response_body["data"]["obtainKrakenToken"]["token"]
self._graphql_expiration = now() + timedelta(hours=1)
else:
_LOGGER.error("Failed to retrieve auth token")
except TimeoutError:
_LOGGER.warning(f'Failed to connect. Timeout of {self._timeout} exceeded.')
raise TimeoutException()
with self._refresh_token_lock:
# Check that our token wasn't refreshed while waiting for the lock
if (self._graphql_expiration is not None and (self._graphql_expiration - timedelta(minutes=5)) > now()):
return

try:
async with aiohttp.ClientSession(timeout=self._timeout, headers=self._default_headers) as client:
url = f'{self._base_url}/v1/graphql/'
payload = { "query": api_token_query.format(api_key=self._api_key) }
async with client.post(url, json=payload) as token_response:
token_response_body = await self.__async_read_response__(token_response, url)
if (token_response_body is not None and
"data" in token_response_body and
"obtainKrakenToken" in token_response_body["data"] and
token_response_body["data"]["obtainKrakenToken"] is not None and
"token" in token_response_body["data"]["obtainKrakenToken"]):

self._graphql_token = token_response_body["data"]["obtainKrakenToken"]["token"]
self._graphql_expiration = now() + timedelta(hours=1)
else:
_LOGGER.error("Failed to retrieve auth token")
except TimeoutError:
_LOGGER.warning(f'Failed to connect. Timeout of {self._timeout} exceeded.')
raise TimeoutException()

async def async_get_account(self, account_id):
"""Get the user's account"""
Expand Down

0 comments on commit 0ecd453

Please sign in to comment.