diff --git a/src/openstack_billing_db/fetch.py b/src/openstack_billing_db/fetch.py index b3ef768..ca588f5 100644 --- a/src/openstack_billing_db/fetch.py +++ b/src/openstack_billing_db/fetch.py @@ -4,6 +4,8 @@ import subprocess import boto3 +import requests +from requests.auth import HTTPBasicAuth logger = logging.getLogger(__name__) @@ -112,3 +114,57 @@ def convert_mysqldump_to_sqlite(path_to_dump) -> str: logger.info(f"Converted at {destination_path}.") return destination_path + + +def get_keycloak_session(): + """Authenticate as a client with Keycloak to receive an access token.""" + keycloak_token_url = os.getenv( + "KEYCLOAK_TOKEN_URL", + ("https://keycloak.mss.mghpcc.org/auth/realms/mss" + "/protocol/openid-connect/token") + ) + keycloak_client_id = os.getenv("KEYCLOAK_CLIENT_ID") + keycloak_client_secret = os.getenv("KEYCLOAK_CLIENT_SECRET") + + if not keycloak_client_id or not keycloak_client_secret: + raise Exception("Must provide KEYCLOAK_CLIENT_ID and" + " KEYCLOAK_CLIENT_SECRET environment variables.") + + r = requests.post( + keycloak_token_url, + data={"grant_type": "client_credentials"}, + auth=HTTPBasicAuth(keycloak_client_id, keycloak_client_secret), + ) + client_token = r.json()["access_token"] + + session = requests.session() + headers = { + "Authorization": f"Bearer {client_token}", + "Content-Type": "application/json", + } + session.headers.update(headers) + return session + +def download_coldfront_data(download_location=None) -> str: + """Downloads allocation data from the ColdFront API. + + Returns location of downloaded JSON file. + """ + + colfront_url = os.getenv("COLDFRONT_URL", + "https://coldfront.mss.mghpcc.org") + allocations_url = f"{colfront_url}/api/allocations?all=true" + + logger.info(f"Making request to ColdFront at {allocations_url}") + r = get_keycloak_session().get(allocations_url) + + if not r.status_code == 200: + raise Exception(f"{r.status_code} Error making API request to ColdFront.") + + if not download_location: + download_location = "/tmp/coldfront_data.json" + with open(download_location, "w") as f: + f.write(r.text) + + logger.info(f"Downloaded ColdFront data at {download_location}.") + return download_location diff --git a/src/openstack_billing_db/main.py b/src/openstack_billing_db/main.py index 9271f3e..7a426b9 100644 --- a/src/openstack_billing_db/main.py +++ b/src/openstack_billing_db/main.py @@ -40,7 +40,18 @@ def main(): "--coldfront-data-file", default=None, help=("Path to JSON Output of ColdFront's /api/allocations." - "Used for populating project names and PIs.") + "Used for populating project names and PIs. If" + " --download-coldfront-data option is applied, this" + " location will be used to save the downloaded output.") + ) + parser.add_argument( + "--download-coldfront-data", + default=False, + help=("Download ColdFront data from ColdFront. Requires the environment" + " variables KEYCLOAK_CLIENT_ID and KEYCLOAK_CLIENT_SECRET." + " Default to NERC Keycloak and ColdFront but can be" + " configure using KEYCLOAK_TOKEN_URL and COLDFRONT_URL environment" + " variables.") ) parser.add_argument( "--sql-dump-file", @@ -132,6 +143,13 @@ def main(): raise Exception("Must provide either --sql_dump_file" "or --download_dump_from_s3.") + coldfront_data_file = args.coldfront_data_file + if args.download_coldfront_data: + coldfront_data_file = fetch.download_coldfront_data(coldfront_data_file) + + if coldfront_data_file: + logger.info(f"Using ColdFront data file at {coldfront_data_file}.") + rates = billing.Rates( cpu=args.rate_cpu_su, gpu_a100=args.rate_gpu_a100_su, @@ -146,7 +164,7 @@ def main(): args.end, args.output, rates, - coldfront_data_file=args.coldfront_data_file, + coldfront_data_file=coldfront_data_file, invoice_month=args.invoice_month, upload_to_s3=args.upload_to_s3, sql_dump_file=dump_file,