Skip to content

Commit

Permalink
feat: add argument for aiohttp client session
Browse files Browse the repository at this point in the history
  • Loading branch information
Lash-L committed Feb 21, 2023
1 parent e1dbda2 commit bc19401
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 36 deletions.
18 changes: 13 additions & 5 deletions src/southern_company_api/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,24 @@ class MonthlyUsage:


class Account:
def __init__(self, name: str, primary: bool, number: str, company: Company):
def __init__(
self,
name: str,
primary: bool,
number: str,
company: Company,
session: aiohttp.ClientSession,
):
self.name = name
self.primary = primary
self.number = number
self.company = company
self.hourly_data: typing.Dict[str, HourlyEnergyUsage] = {}
self.daily_data: typing.Dict[str, DailyEnergyUsage] = {}
self.session = session

async def get_service_point_number(self, jwt: str) -> str:
async with aiohttp.ClientSession() as session:
async with self.session as session:
headers = {"Authorization": f"bearer {jwt}"}
# TODO: Is the /GPC for all customers or just GA power?
try:
Expand Down Expand Up @@ -83,7 +91,7 @@ async def get_daily_data(
) -> List[DailyEnergyUsage]:
"""Available 24 hours after"""
"""This is not really tested yet."""
async with aiohttp.ClientSession() as session:
async with self.session as session:
headers = {"Authorization": f"bearer {jwt}"}
params = {
"accountNumber": self.number,
Expand Down Expand Up @@ -165,7 +173,7 @@ async def get_hourly_data(
continue
cur_date = cur_date + datetime.timedelta(days=35)
return return_data
async with aiohttp.ClientSession() as session:
async with self.session as session:
# Needs to check if the data already exist in self.hourly_data to avoid making an unneeded call.
headers = {"Authorization": f"bearer {jwt}"}
params = {
Expand Down Expand Up @@ -226,7 +234,7 @@ async def get_hourly_data(

async def get_month_data(self, jwt: str) -> MonthlyUsage:
"""Gets monthly data such as usage so far"""
async with aiohttp.ClientSession() as session:
async with self.session as session:
headers = {"Authorization": f"bearer {jwt}"}
today = datetime.datetime.now()
first_of_month = today.replace(day=1)
Expand Down
24 changes: 13 additions & 11 deletions src/southern_company_api/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import List

import aiohttp as aiohttp
from aiohttp import ContentTypeError
from aiohttp import ClientSession, ContentTypeError

from southern_company_api.account import Account

Expand All @@ -20,13 +20,13 @@
)


async def get_request_verification_token() -> str:
async def get_request_verification_token(session: ClientSession) -> str:
"""
Get the request verification token, which allows us to get a login session
:return: the verification token
"""
try:
async with aiohttp.ClientSession() as session:
async with session:
http_response = await session.get(
"https://webauth.southernco.com/account/login"
)
Expand All @@ -40,7 +40,8 @@ async def get_request_verification_token() -> str:


class SouthernCompanyAPI:
def __init__(self, username: str, password: str):
def __init__(self, username: str, password: str, session: ClientSession):
self.session = session
self.username = username
self.password = password
self._jwt: typing.Optional[str] = None
Expand Down Expand Up @@ -75,7 +76,7 @@ async def request_token(self) -> str:
self._request_token is None
or datetime.datetime.now() >= self._request_token_expiry
):
self._request_token = await get_request_verification_token()
self._request_token = await get_request_verification_token(self.session)
self._request_token_expiry = datetime.datetime.now() + datetime.timedelta(
hours=3
)
Expand All @@ -85,14 +86,14 @@ async def connect(self) -> None:
"""
Connects to Southern company and gets all accounts
"""
self._request_token = await get_request_verification_token()
self._request_token = await get_request_verification_token(self.session)
self._sc = await self._get_sc_web_token()
self._jwt = await self.get_jwt()
self._accounts = await self.get_accounts()

async def authenticate(self) -> bool:
"""Determines if you can authenticate with Southern Company with given login"""
self._request_token = await get_request_verification_token()
self._request_token = await get_request_verification_token(self.session)
self._sc = await self._get_sc_web_token()
return True

Expand All @@ -112,7 +113,7 @@ async def _get_sc_web_token(self) -> str:
"params": {"ReturnUrl": "null"},
}

async with aiohttp.ClientSession() as session:
async with self.session as session:
async with session.post(
"https://webauth.southernco.com/api/login", json=data, headers=headers
) as response:
Expand Down Expand Up @@ -140,7 +141,7 @@ async def _get_southern_jwt_cookie(self) -> str:
if await self.sc is None:
raise CantReachSouthernCompany("Sc token cannot be refreshed")
data = {"ScWebToken": self._sc}
async with aiohttp.ClientSession() as session:
async with self.session as session:
async with session.post(
"https://customerservice2.southerncompany.com/Account/LoginComplete?"
"ReturnUrl=null",
Expand Down Expand Up @@ -180,7 +181,7 @@ async def get_jwt(self) -> str:
# Now fetch JWT after secondary ScWebToken
# NOTE: This used to be ScWebToken before 02/07/2023
headers = {"Cookie": f"SouthernJwtCookie={swtoken}"}
async with aiohttp.ClientSession() as session:
async with self.session as session:
async with session.get(
"https://customerservice2.southerncompany.com/Account/LoginValidated/"
"JwtToken",
Expand Down Expand Up @@ -223,7 +224,7 @@ async def get_accounts(self) -> List[Account]:
if await self.jwt is None:
raise CantReachSouthernCompany("Can't get jwt. Expired and not refreshed")
headers = {"Authorization": f"bearer {self._jwt}"}
async with aiohttp.ClientSession() as session:
async with self.session as session:
async with session.get(
"https://customerservice2api.southerncompany.com/api/account/"
"getAllAccounts",
Expand All @@ -249,6 +250,7 @@ async def get_accounts(self) -> List[Account]:
primary=account["PrimaryAccount"] == "Y",
number=account["AccountNumber"],
company=COMPANY_MAP.get(account["Company"], Company.GPC),
session=self.session,
)
)
self._accounts = accounts
Expand Down
7 changes: 4 additions & 3 deletions tests/test_account.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
import datetime
from unittest.mock import patch

import aiohttp
import pytest

from southern_company_api import Account, Company
from tests import MockResponse, test_get_hourly_usage, test_get_month_data


def test_can_create():
Account("sample", True, "1", Company.GPC)
Account("sample", True, "1", Company.GPC, aiohttp.ClientSession())


@pytest.mark.asyncio
async def test_get_hourly_data():
acc = Account("sample", True, "1", Company.GPC)
acc = Account("sample", True, "1", Company.GPC, aiohttp.ClientSession())
with patch(
"src.southern_company_api.account.aiohttp.ClientSession.get"
) as mock_get, patch(
Expand All @@ -31,7 +32,7 @@ async def test_get_hourly_data():

@pytest.mark.asyncio
async def test_ga_power_get_monthly_data():
acc = Account("sample", True, "1", Company.GPC)
acc = Account("sample", True, "1", Company.GPC, aiohttp.ClientSession())
with patch(
"src.southern_company_api.account.aiohttp.ClientSession.get"
) as mock_get, patch(
Expand Down
35 changes: 18 additions & 17 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# For get_accounts jwt mock. looking for better solution.
import datetime
import typing
from unittest.mock import patch
from unittest.mock import AsyncMock, patch

import aiohttp
import pytest

from southern_company_api import Account
Expand All @@ -26,12 +27,12 @@


def test_can_create():
SouthernCompanyAPI("user", "pass")
SouthernCompanyAPI("user", "pass", aiohttp.ClientSession())


@pytest.mark.asyncio
async def test_get_request_verification_token():
token = await get_request_verification_token()
token = await get_request_verification_token(aiohttp.ClientSession())
assert len(token) > 1


Expand All @@ -45,11 +46,7 @@ async def test_get_request_verification_token():
@pytest.mark.asyncio
async def test_get_request_verification_token_fail():
with pytest.raises(CantReachSouthernCompany):
with patch(
"src.southern_company_api.parser.aiohttp.ClientSession",
side_effect=Exception(),
):
await get_request_verification_token()
await get_request_verification_token(AsyncMock(side_effect=Exception()))


@pytest.mark.asyncio
Expand All @@ -58,7 +55,7 @@ async def test_cant_find_request_token():
"src.southern_company_api.parser.aiohttp.ClientResponse.text", return_value=""
):
with pytest.raises(NoRequestTokenFound):
await get_request_verification_token()
await get_request_verification_token(aiohttp.ClientSession())


@pytest.mark.asyncio
Expand All @@ -70,7 +67,7 @@ async def test_can_authenticate():
) as mock__get_sc_web_token:
mock_get_request_verification_token.return_value = "fake_token"
mock__get_sc_web_token.return_value = "fake_sc"
api = SouthernCompanyAPI("", "")
api = SouthernCompanyAPI("", "", aiohttp.ClientSession())
result = await api.authenticate()
assert result is True
mock_get_request_verification_token.assert_called_once()
Expand All @@ -81,17 +78,21 @@ async def test_can_authenticate():
async def test_ga_power_get_sc_web_token():
with patch("southern_company_api.parser.aiohttp.ClientSession.post") as mock_post:
mock_post.return_value = MockResponse("", 200, "", ga_power_sample_sc_response)
sca = SouthernCompanyAPI("", "")
sca = SouthernCompanyAPI("", "", aiohttp.ClientSession())
sca._request_token = "sample"
response_token = await sca._get_sc_web_token()
assert response_token == "sample_sc_token"


@pytest.mark.asyncio
async def test_get_sc_web_token_wrong_login():
sca = SouthernCompanyAPI("user", "pass")
with pytest.raises(InvalidLogin):
await sca._get_sc_web_token()
sca = SouthernCompanyAPI("user", "pass", aiohttp.ClientSession())
with patch(
"src.southern_company_api.parser.aiohttp.ClientSession.post"
) as mock_post:
mock_post.return_value = MockResponse("", 200, "", {"statusCode": 500})
with pytest.raises(InvalidLogin):
await sca._get_sc_web_token()


@pytest.mark.asyncio
Expand All @@ -102,7 +103,7 @@ async def test_ga_power_get_jwt_cookie():
mock_post.return_value = MockResponse(
"", 200, ga_power_southern_jwt_cookie_header, ""
)
sca = SouthernCompanyAPI("", "")
sca = SouthernCompanyAPI("", "", aiohttp.ClientSession())
sca._sc = ""
sca._sc_expiry = datetime.datetime.now() + datetime.timedelta(hours=3)
token = await sca._get_southern_jwt_cookie()
Expand All @@ -118,7 +119,7 @@ async def test_ga_power_get_jwt():
) as mock_get_cookie:
mock_get.return_value = MockResponse("", 200, ga_power_jwt_header, "")
mock_get_cookie.return_value.__aenter__.return_value = ""
sca = SouthernCompanyAPI("", "")
sca = SouthernCompanyAPI("", "", aiohttp.ClientSession())
token = await sca.get_jwt()
assert token == "sample_jwt"

Expand All @@ -136,7 +137,7 @@ async def mock_jwt(_foo_self: SouthernCompanyAPI) -> str:
return ""

with patch.object(SouthernCompanyAPI, "jwt", new=mock_jwt):
sca = SouthernCompanyAPI("", "")
sca = SouthernCompanyAPI("", "", aiohttp.ClientSession())
response_token: typing.List[Account] = await sca.get_accounts()
assert response_token[0].name == "Home Energy"
assert sca._accounts == response_token

0 comments on commit bc19401

Please sign in to comment.