Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CI + database unit testing & method generalization #5

Merged
merged 2 commits into from
Oct 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
name: Run Python Tests
on:
push

jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/[email protected]
- name: Setup Python
uses: actions/[email protected]
with:
python-version: '3.10'
- name: Install dependencies
run: pip install -r requirements_dev.txt
- name: Run tests
run: python -m unittest discover -s tibberios/tests
13 changes: 11 additions & 2 deletions tibberios/cli.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .core import Config, Database, TibberConnector
from .visualization import GenerateViz
from datetime import datetime
from pprint import pprint


async def main(
Expand Down Expand Up @@ -39,7 +40,15 @@ async def main(
tib = TibberConnector(config.tibber_api_key)
price_data = await tib.get_price_data(resolution=resolution, records=records)

db.create_table()
tbl_name = "consumption"
columns = {
"start_time": "DATE PRIMARY KEY",
"unit_price": "REAL",
"total_cost": "REAL",
"cost": "REAL",
"consumption": "REAL",
}
db.create_table(name=tbl_name, cols_n_types=columns)
if verbose:
print("Consumption table created")
db.upsert_table(values=price_data.price_table)
Expand All @@ -49,7 +58,7 @@ async def main(
if verbose:
print("Consumption values upserted")
print("Latest 10 consumption values:")
db.show_latest_data()
pprint(db.get_latest_data(name=tbl_name, order="start_time"))

# TODO: make into subcommand using Python click
if generate_vis:
Expand Down
142 changes: 105 additions & 37 deletions tibberios/core.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from dataclasses import dataclass
from pprint import pprint
from collections import Counter

import tibber
from aiohttp import ClientSession

from os import remove


@dataclass
class Config:
Expand Down Expand Up @@ -127,7 +128,15 @@ async def get_price_data(self, resolution: str, records: int) -> PriceData:


class Database:
"""For smoother SQLite3 database operations."""

def __init__(self, filename: str) -> None:
"""An interface for performing common database operations on a SQLite3 database.

Args
----
filename: The path to the SQLite3 database file. Will get created if it doesn't exist.
"""
from sqlite3 import connect

self._database_path = filename
Expand All @@ -137,62 +146,121 @@ def __del__(self) -> None:
if hasattr(self, "connection") and self.connection:
self.close()

def create_table(self) -> None:
# TODO: Make generic
query = """
CREATE TABLE IF NOT EXISTS consumption(
start_time DATE PRIMARY KEY,
unit_price REAL,
total_cost REAL,
cost REAL,
consumption REAL
def delete_database(self) -> bool:
"""DESTROY THE DATABASE.

Returns
-------
bool: True if the database was deleted, raises exception otherwise.
"""
remove(self._database_path)
return True

def create_table(self, name: str, cols_n_types: dict) -> None:
"""Create a table in the database.

Args
----
name: The table name
cols_n_types: A dictionary with keys as the column names and values as SQLite data types
"""
query = f"""
CREATE TABLE IF NOT EXISTS {name} (
{','.join([c + " " + t for c, t in cols_n_types.items()])}
);
"""
cursor = self.connection
cursor = cursor.execute(query)
self.connection.commit()

def show_latest_data(self, limit: int = 10) -> None:
# TODO: Make generic
def get_latest_data(self, name: str, order: str, limit: int = 10) -> list[tuple]:
"""Get the latest values from a table.

Args
----
name: The table name
order: The column name by which to order the results
limit: The number of results to return

Returns
-------
list: The latest data as queried from the database table,
as a list of tuples where each row is represented in a tuple.
"""
query = f"""
SELECT *
FROM consumption
ORDER BY start_time DESC
FROM {name}
ORDER BY {order} DESC
LIMIT {limit};
"""
cursor = self.connection
cursor = cursor.execute(query)
pprint(cursor.fetchall())

def upsert_table(self, values: list[tuple]) -> None:
# TODO: Make generic
query = """
INSERT INTO consumption(
start_time
, unit_price
, total_cost
, cost
, consumption
return cursor.fetchall()

def insert_table(self, name: str, columns: list[str], values: list[tuple]) -> None:
"""Insert values to a table

Args
----
name: The table name
columns: The names of the columns in the table
values: The values to insert in the table
"""
query = f"""
INSERT INTO {name} (
{','.join(columns)}
)
VALUES(?, ?, ?, ?, ?)
ON CONFLICT(start_time) DO UPDATE SET
unit_price = excluded.unit_price
, total_cost = excluded.total_cost
, cost = excluded.cost
, consumption = excluded.consumption;
VALUES ({','.join('?'*len(columns))})
"""
cursor = self.connection
cursor.executemany(query, values)
self.connection.commit()

def upsert_table(
self, name: str, columns: list[str], values: list[tuple], pk: str
) -> None:
"""Upsert a table aka insert values and overwrite if pk already has values.

Args
----
name: The table name
columns: The names of the columns in the table
values: The values to insert in the table
pk: The private key of the table
"""
n_cols = len(columns)
for i, v in enumerate(values):
assert n_cols == len(
v
), f"Row {i} in received values contains {len(v)} values, expected {n_cols}"

cols_to_update = set(columns) - set([pk])
query = f"""
INSERT INTO {name} (
{','.join(columns)}
)
VALUES ({','.join('?'*len(columns))})
ON CONFLICT ({pk}) DO UPDATE SET
{','.join([f'{c} = excluded.{c}' for c in cols_to_update])}
;
"""
cursor = self.connection
cursor.executemany(query, values)
self.connection.commit()

def delete_null_rows(self) -> None:
# TODO: Make generic
query = """
def delete_null_rows(self, name: str, pk: str) -> None:
"""Delete rows where the pk

Args
----
name: The table name
pk: The private key of the table
"""
query = f"""
DELETE
FROM consumption
WHERE start_time IS NULL
OR trim(start_time) = '';
FROM {name}
WHERE {pk} IS NULL
OR trim({pk}) = '';
"""
cursor = self.connection
cursor.execute(query)
Expand Down
57 changes: 57 additions & 0 deletions tibberios/tests/test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import unittest
from os import path
from sqlite3 import OperationalError

from tibberios.core import Database


class TestDatabaseMethods(unittest.TestCase):
def setUp(self):
self.db = Database("test_tibberios.db")
self.name = "test"
self.cols_n_types = {
"col1": "REAL PRIMARY KEY",
"col2": "REAL",
"col3": "REAL",
}
self.columns = self.cols_n_types.keys()
self.values = [(1, 2, 3), (4, 5, 6), (7, 8, 9)]
self.new_values = [(1, 3, 4), (4, 6, 7), (7, 9, 9)]
self.pk = "col1"

def test_database_connection(self):
# make sure the database connection works
with self.assertRaises(OperationalError):
_ = self.db.get_latest_data(name="nonexistingtable", order="bad_column")

def test_database_file_created(self):
self.assertTrue(
path.exists(self.db._database_path),
)

def test_database_operations(self):
# table creation
self.db.create_table(name=self.name, cols_n_types=self.cols_n_types)
results = self.db.get_latest_data(name=self.name, order=self.pk)
# table should be empty
self.assertEqual(len(results), 0)

# add data to table
self.db.insert_table(name=self.name, columns=self.columns, values=self.values)
results = self.db.get_latest_data(name=self.name, order=self.pk)
self.assertEqual(len(self.values), len(results))

# upsert table
self.db.upsert_table(
name=self.name, columns=self.columns, values=self.values, pk=self.pk
)
results = self.db.get_latest_data(name=self.name, order=self.pk)
# table should have values
self.assertEqual(len(results), len(self.values))

def tearDown(self):
self.db.delete_database()


if __name__ == "__main__":
unittest.main()