
Commit

MobilityData#898 Write to CSV file incrementally
sylvansson committed Feb 2, 2025
1 parent aeb5585 commit 8ad6712
Showing 3 changed files with 71 additions and 77 deletions.
121 changes: 64 additions & 57 deletions functions-python/export_csv/src/main.py
@@ -14,11 +14,11 @@
# limitations under the License.
#
import argparse
import csv
import logging
import os
import re

import pandas as pd
from typing import Dict, Iterator

from dotenv import load_dotenv
import functions_framework
@@ -30,53 +30,48 @@

from shared.helpers.logger import Logger
from shared.database_gen.sqlacodegen_models import Gtfsfeed, Gtfsrealtimefeed
from collections import OrderedDict
from shared.common.db_utils import get_all_gtfs_rt_feeds_query, get_all_gtfs_feeds_query

from shared.helpers.database import Database

load_dotenv()
csv_default_file_path = "./output.csv"


class DataCollector:
"""
A class used to collect and organize data into rows and headers for CSV output.
One particularity of this class is that it uses an OrderedDict to store the data, so that the order of the columns
is preserved when writing to CSV.
"""

def __init__(self):
self.data = OrderedDict()
self.rows = []
self.headers = []

def add_data(self, key, value):
if key not in self.headers:
self.headers.append(key)
self.data[key] = value

def finalize_row(self):
self.rows.append(self.data.copy())
self.data = OrderedDict()

def write_csv_to_file(self, csv_file_path):
df = pd.DataFrame(self.rows, columns=self.headers)
df.to_csv(csv_file_path, index=False)

def get_dataframe(self) -> pd:
return pd.DataFrame(self.rows, columns=self.headers)
# This needs to be updated if we add fields to either `get_feed_csv_data` or
# `get_gtfs_rt_feed_csv_data`, otherwise the extra field(s) will be excluded from
# the generated CSV file.
headers = [
"id",
"data_type",
"entity_type",
"location.country_code",
"location.subdivision_name",
"location.municipality",
"provider",
"name",
"note",
"feed_contact_email",
"static_reference",
"urls.direct_download",
"urls.authentication_type",
"urls.authentication_info",
"urls.api_key_parameter_name",
"urls.latest",
"urls.license",
"location.bounding_box.minimum_latitude",
"location.bounding_box.maximum_latitude",
"location.bounding_box.minimum_longitude",
"location.bounding_box.maximum_longitude",
"location.bounding_box.extracted_on",
"status",
"features",
"redirect.id",
"redirect.comment",
]
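
The comment above warns that `headers` must stay in sync with the row-building functions. A standalone illustration of why (not part of the commit): by default csv.DictWriter rejects a row key that is absent from fieldnames, and with extrasaction="ignore" it silently drops it.

import csv
import io

# Hypothetical two-column header list standing in for the real `headers`.
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=["id", "provider"])
writer.writeheader()
writer.writerow({"id": "mdb-1", "provider": "Example Transit"})
try:
    # "note" is not in fieldnames, so the default extrasaction="raise"
    # refuses the row instead of silently losing the field.
    writer.writerow({"id": "mdb-2", "provider": "Example Transit", "note": "x"})
except ValueError as error:
    print(error)  # dict contains fields not in fieldnames: 'note'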


@functions_framework.http
def export_and_upload_csv(request=None):
csv_file_path = csv_default_file_path
response = export_csv(csv_file_path)
upload_file_to_storage(csv_file_path, "sources_v2.csv")
return response


def export_csv(csv_file_path: str):
"""
HTTP Function entry point. Reads the DB and outputs a CSV file with feeds data.
This function requires the following environment variables to be set:
@@ -85,16 +80,36 @@ def export_csv(csv_file_path: str):
:return: HTTP response object
"""
Logger.init_logger()
logging.info("Function Started")
data_collector = collect_data()
data_collector.write_csv_to_file(csv_file_path)
return f"Exported {len(data_collector.rows)} feeds to CSV file {csv_file_path}."
logging.info("Export started")

csv_file_path = csv_default_file_path
export_csv(csv_file_path)
upload_file_to_storage(csv_file_path, "sources_v2.csv")

logging.info("Export successful")
return "Export successful"


def export_csv(csv_file_path: str):
"""
Write feed data to a local CSV file.
"""
with open(csv_file_path, "w") as out:
writer = csv.DictWriter(out, fieldnames=headers)
writer.writeheader()

def collect_data() -> DataCollector:
count = 0
for feed in fetch_feeds():
writer.writerow(feed)
count += 1

logging.info(f"Exported {count} feeds to CSV file {csv_file_path}.")


def fetch_feeds() -> Iterator[Dict]:
"""
Collect data from the DB and write the output to a DataCollector.
:return: A filled DataCollector
Fetch and return feed data from the DB.
:return: Data to write to the output CSV file.
"""
db = Database(database_url=os.getenv("FEEDS_DATABASE_URL"))
logging.info(f"Using database {db.database_url}")
@@ -118,27 +133,19 @@ def collect_data() -> DataCollector:

logging.info(f"Retrieved {len(gtfs_rt_feeds)} GTFS realtime feeds.")

data_collector = DataCollector()

for feed in gtfs_feeds:
data = get_feed_csv_data(feed)
yield get_feed_csv_data(feed)

for key, value in data.items():
data_collector.add_data(key, value)
data_collector.finalize_row()
logging.info(f"Processed {len(gtfs_feeds)} GTFS feeds.")

for feed in gtfs_rt_feeds:
data = get_gtfs_rt_feed_csv_data(feed)
for key, value in data.items():
data_collector.add_data(key, value)
data_collector.finalize_row()
yield get_gtfs_rt_feed_csv_data(feed)

logging.info(f"Processed {len(gtfs_rt_feeds)} GTFS realtime feeds.")

except Exception as error:
logging.error(f"Error retrieving feeds: {error}")
raise Exception(f"Error retrieving feeds: {error}")
return data_collector


def extract_numeric_version(version):
@@ -233,8 +240,8 @@ def get_feed_csv_data(feed: Gtfsfeed):
"location.bounding_box.maximum_latitude": maximum_latitude,
"location.bounding_box.minimum_longitude": minimum_longitude,
"location.bounding_box.maximum_longitude": maximum_longitude,
"location.bounding_box.extracted_on": validated_at,
# We use the report validated_at date as the extracted_on date
"location.bounding_box.extracted_on": validated_at,
"status": feed.status,
"features": joined_features,
}
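The core of the change: instead of accumulating every row in a DataCollector and dumping it through a pandas DataFrame, export_csv now streams rows from the fetch_feeds() generator straight into csv.DictWriter. A minimal runnable sketch of the same pattern, with a hypothetical fetch_rows() generator and a two-column header list standing in for the real ones:

import csv
from typing import Dict, Iterator

# Hypothetical stand-ins for the module's `headers` list and fetch_feeds().
headers = ["id", "provider"]

def fetch_rows() -> Iterator[Dict]:
    # The real fetch_feeds() streams rows out of the database queries;
    # here we just yield literal dicts.
    yield {"id": "mdb-1", "provider": "Example Transit"}
    yield {"id": "mdb-2", "provider": "Sample Rail"}

def export_csv(csv_file_path: str) -> None:
    with open(csv_file_path, "w", newline="") as out:
        writer = csv.DictWriter(out, fieldnames=headers)
        writer.writeheader()
        count = 0
        for row in fetch_rows():
            # Each row hits the file as soon as it is produced, so peak
            # memory stays flat no matter how many feeds there are.
            writer.writerow(row)
            count += 1
    print(f"Exported {count} rows to {csv_file_path}.")

export_csv("./output.csv")
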
2 changes: 1 addition & 1 deletion functions-python/export_csv/tests/conftest.py
@@ -51,7 +51,7 @@ def populate_database():
feeds = []
# We create 3 feeds. The first one is active. The third one is inactive and redirected to the first one.
# The second one is active but not redirected.
# First fill the generic paramaters
# First fill the generic parameters
for i in range(3):
feed = Gtfsfeed(
data_type="gtfs",
25 changes: 6 additions & 19 deletions functions-python/export_csv/tests/test_export_csv_main.py
@@ -39,27 +39,14 @@ def test_export_csv():
os.environ[
"FEEDS_DATABASE_URL"
] = "postgresql://postgres:postgres@localhost:54320/MobilityDatabaseTest"
data_collector = main.collect_data()
print(f"Collected data for {len(data_collector.rows)} feeds.")

df_extracted = data_collector.get_dataframe()
csv_file_path = "./output.csv"
main.export_csv(csv_file_path)
df_actual = pd.read_csv(csv_file_path)
print(f"Collected data for {len(df_actual)} feeds.")

csv_buffer = io.StringIO(expected_csv)
df_from_expected_csv = pd.read_csv(csv_buffer)
df_from_expected_csv.fillna("", inplace=True)

df_extracted.fillna("", inplace=True)

df_extracted["urls.authentication_type"] = df_extracted[
"urls.authentication_type"
].astype(str)
df_from_expected_csv["urls.authentication_type"] = df_from_expected_csv[
"urls.authentication_type"
].astype(str)
df_from_expected_csv["location.bounding_box.extracted_on"] = pd.to_datetime(
df_from_expected_csv["location.bounding_box.extracted_on"], utc=True
)
df_expected = pd.read_csv(io.StringIO(expected_csv))

# try:
pdt.assert_frame_equal(df_extracted, df_from_expected_csv)
pdt.assert_frame_equal(df_actual, df_expected)
print("DataFrames are equal.")
