diff --git a/api/src/database/database.py b/api/src/database/database.py
index 080507677..ce016542c 100644
--- a/api/src/database/database.py
+++ b/api/src/database/database.py
@@ -5,9 +5,9 @@
 from typing import Type, Callable
 from dotenv import load_dotenv
 from sqlalchemy import create_engine, inspect
-from sqlalchemy.orm import load_only, Query
+from sqlalchemy.orm import load_only, Query, class_mapper
 
-from database_gen.sqlacodegen_models import Base
+from database_gen.sqlacodegen_models import Base, Feed, Gtfsfeed, Gtfsrealtimefeed
 from sqlalchemy.orm import sessionmaker
 import logging
 from typing import Final
@@ -26,6 +26,24 @@ def generate_unique_id() -> str:
     return str(uuid.uuid4())
 
 
+def configure_polymorphic_mappers():
+    """
+    Configure the polymorphic mappers, allowing polymorphic values on relationships.
+    """
+    feed_mapper = class_mapper(Feed)
+    # Configure the polymorphic mapper using data_type as the discriminator for the Feed class
+    feed_mapper.polymorphic_on = Feed.data_type
+    feed_mapper.polymorphic_identity = Feed.__tablename__.lower()
+
+    gtfsfeed_mapper = class_mapper(Gtfsfeed)
+    gtfsfeed_mapper.inherits = feed_mapper
+    gtfsfeed_mapper.polymorphic_identity = Gtfsfeed.__tablename__.lower()
+
+    gtfsrealtimefeed_mapper = class_mapper(Gtfsrealtimefeed)
+    gtfsrealtimefeed_mapper.inherits = feed_mapper
+    gtfsrealtimefeed_mapper.polymorphic_identity = Gtfsrealtimefeed.__tablename__.lower()
+
+
 class Database:
     """
     This class represents a database instance
@@ -40,11 +58,12 @@ def __new__(cls, *args, **kwargs):
             cls.instance = object.__new__(cls)
         return cls.instance
 
-    def __init__(self):
+    def __init__(self, echo_sql=True):
         load_dotenv()
         self.engine = None
         self.connection_attempts = 0
         self.SQLALCHEMY_DATABASE_URL = os.getenv("FEEDS_DATABASE_URL")
+        self.echo_sql = echo_sql
         self.start_session()
 
     def is_connected(self):
@@ -81,7 +100,7 @@ def start_new_db_session(self):
             raise Exception("Database URL is not set")
         else:
             logging.info("Starting new global database session.")
-        self.engine = create_engine(self.SQLALCHEMY_DATABASE_URL, echo=True)
+        self.engine = create_engine(self.SQLALCHEMY_DATABASE_URL, echo=self.echo_sql)
         global_session = sessionmaker(bind=self.engine)()
         self.session = global_session
         return global_session
diff --git a/api/src/scripts/populate_db.py b/api/src/scripts/populate_db.py
index 84d2b2e85..d7f40666d 100644
--- a/api/src/scripts/populate_db.py
+++ b/api/src/scripts/populate_db.py
@@ -1,13 +1,13 @@
 import argparse
 import os
 from pathlib import Path
-from queue import PriorityQueue
+from typing import Type
 
 import pandas
 from dotenv import load_dotenv
-from sqlalchemy import inspect, text
+from sqlalchemy import text
 
-from database.database import Database, generate_unique_id
+from database.database import Database, generate_unique_id, configure_polymorphic_mappers
 from database_gen.sqlacodegen_models import (
     Entitytype,
     Externalid,
@@ -15,12 +15,17 @@
     Gtfsrealtimefeed,
     Location,
     Redirectingid,
-    Base,
     t_feedsearch,
+    Feed,
 )
 from utils.data_utils import set_up_defaults
 from utils.logger import Logger
+import logging
+
+logging.basicConfig()
+logging.getLogger("sqlalchemy.engine").setLevel(logging.ERROR)
+
 
 
 def set_up_configs():
     """
@@ -43,176 +48,135 @@ class DatabasePopulateHelper:
     def __init__(self, filepath):
         self.logger = Logger(self.__class__.__module__).get_logger()
-        self.db = Database()
+        self.logger.setLevel(logging.INFO)
+        self.db = Database(echo_sql=False)
         self.df = pandas.read_csv(filepath)  # contains the data to populate the database
         # Filter unsupported data types
         self.df = self.df[(self.df.data_type == "gtfs") | (self.df.data_type == "gtfs-rt")]
         self.df = set_up_defaults(self.df)
-        self.logger.info(self.df)
 
-    def fast_merge(self, orm_object: Base):
+    @staticmethod
+    def get_model(data_type: str | None) -> Type[Gtfsrealtimefeed | Gtfsfeed | Feed]:
         """
-        Faster merge of an orm object that strictly validates the PK in the active session
-        This method assumes that the object is clean i.e. not present in the database
-        :param orm_object: the object to merge
-        :return: True if merge was successful, False otherwise
+        Get the model based on the data type
         """
-        try:
-            # Check if an object with the same primary key is already in the session
-            primary_key = inspect(orm_object.__class__).primary_key
-            existing_object = None
-            if primary_key:
-                conditions = [pk == getattr(orm_object, pk.name) for pk in primary_key]
-                existing_objects = self.db.select_from_active_session(orm_object.__class__, conditions)
-                if len(existing_objects) == 1:
-                    existing_object = existing_objects[0]
-
-            if existing_object:
-                # If an object with the same primary key exists, update it with the new data
-                for attr, value in orm_object.__dict__.items():
-                    if attr != "_sa_instance_state":
-                        setattr(existing_object, attr, value)
-                return True
-            else:
-                # Otherwise simply add the object without loading
-                return self.db.session.add(orm_object)
-        except Exception as e:
-            self.logger.error(f"Fast merge query failed with exception: \n{e}")
-            return False
+        if data_type is None:
+            return Feed
+        return Gtfsrealtimefeed if data_type == "gtfs_rt" else Gtfsfeed
 
-    def populate(self):
+    @staticmethod
+    def get_safe_value(row, column_name, default_value):
         """
-        Populates the database
+        Get the stripped string value of the column, or the default when the cell is empty or NaN
         """
-        entities = []  # entities to add to the database
-        entities_index = PriorityQueue()  # prioritization of the entities to avoid FK violation
-
-        def add_entity(entity, priority):
-            # validate that entity is not already added
-            primary_key = inspect(entity.__class__).primary_key
-            for e in [sim_entity for sim_entity in entities if isinstance(sim_entity, type(entity))]:
-                entities_are_equal = all([getattr(e, pk.name) == getattr(entity, pk.name) for pk in primary_key])
-                if entities_are_equal:
-                    return e
-            # add the entity
-            entities_index.put((priority, len(entities)))
-            entities.append(entity)
-            return entity
-
-        if self.df is None:
-            return
+        if not row[column_name] or pandas.isna(row[column_name]) or f"{row[column_name]}".strip() == "":
+            return default_value
+        return f"{row[column_name]}".strip()
 
-        # Gather all stable IDs from csv
-        stable_ids = [f"mdb-{int(row['mdb_source_id'])}" for index, row in self.df.iterrows()]
-
-        # Query once to get all existing feeds information
-        gtfs_feeds = self.db.session.query(Gtfsfeed).filter(Gtfsfeed.stable_id.in_(stable_ids)).all()
-        gtfs_rt_feeds = self.db.session.query(Gtfsrealtimefeed).filter(Gtfsrealtimefeed.stable_id.in_(stable_ids)).all()
-        locations = self.db.session.query(Location)
-        redirects = self.db.session.query(Redirectingid)
-        entity_types = self.db.session.query(Entitytype)
-
-        # Keep dicts (maps) of id -> entity, so we can reference the feed information when processing
-        feed_map = {feed.stable_id: feed for feed in gtfs_feeds + gtfs_rt_feeds}
-        locations_map = {location.id: location for location in locations}
-        entity_types_map = {entity_type.name: entity_type for entity_type in entity_types}
-        redirects_set = {(redirect_entity.source_id, redirect_entity.target_id) for redirect_entity in redirects}
-
-        for index, row in self.df.iterrows():
-            mdb_id = f"mdb-{int(row['mdb_source_id'])}"
-            self.logger.debug(f"Populating Database for with Feed [stable_id = {mdb_id}]")
+    def get_data_type(self, row):
+        """
+        Get the data type from the row
+        """
+        data_type = self.get_safe_value(row, "data_type", "").lower()
+        if data_type not in ["gtfs", "gtfs-rt", "gtfs_rt"]:
+            self.logger.warning(f"Unsupported data type: {data_type}")
+            return None
+        return data_type.replace("-", "_")
 
-            feed_exists = mdb_id in feed_map
-            if not feed_exists:
-                self.logger.info(f"New {row['data_type']} feed with stable_id = {mdb_id} has been added.")
+    def query_feed_by_stable_id(self, stable_id: str, data_type: str | None) -> Gtfsrealtimefeed | Gtfsfeed | None:
+        """
+        Query the feed by stable id
+        """
+        model = self.get_model(data_type)
+        return self.db.session.query(model).filter(model.stable_id == stable_id).first()
 
-            # Feed
-            feed_class = Gtfsfeed if row["data_type"] == "gtfs" else Gtfsrealtimefeed
-            feed = (
-                feed_map[mdb_id]
-                if mdb_id in feed_map
-                else feed_class(
-                    id=generate_unique_id(),
-                    stable_id=mdb_id,
-                )
-            )
-            feed_map[mdb_id] = feed
-            feed.data_type = row["data_type"]
-            feed.feed_name = row["name"]
-            feed.note = row["note"]
-            feed.producer_url = row["urls.direct_download"]
-            feed.authentication_type = str(int(row.get("urls.authentication_type", "0") or "0"))
-            feed.authentication_info_url = row["urls.authentication_info"]
-            feed.api_key_parameter_name = row["urls.api_key_parameter_name"]
-            feed.license_url = row["urls.license"]
-            feed.status = row["status"]
-            feed.provider = row["provider"]
-            feed.feed_contact_email = row["feed_contact_email"]
+    def get_stable_id(self, row):
+        """
+        Get the stable id from the row
+        """
+        return f'mdb-{self.get_safe_value(row, "mdb_source_id", "")}'
 
-            # Location
-            country_code = row["location.country_code"]
-            subdivision_name = row["location.subdivision_name"]
-            municipality = row["location.municipality"]
-            composite_id = f"{country_code}-{subdivision_name}-{municipality}".replace(" ", "_")
-            location_id = composite_id if len(composite_id) > 0 else "unknown"
+    def populate_location(self, feed, row, stable_id):
+        """
+        Populate the location for the feed
+        """
+        country_code = self.get_safe_value(row, "location.country_code", "")
+        subdivision_name = self.get_safe_value(row, "location.subdivision_name", "")
+        municipality = self.get_safe_value(row, "location.municipality", "")
+        composite_id = f"{country_code}-{subdivision_name}-{municipality}".replace(" ", "_")
+        location_id = composite_id if len(composite_id) > 2 else None  # len <= 2 means all three parts were empty
+        if not location_id:
+            self.logger.warning(f"Location ID is empty for feed {stable_id}")
+            feed.locations.clear()
+        else:
+            location = self.db.session.get(Location, location_id)
             location = (
-                locations_map[location_id]
-                if location_id in locations_map
+                location
+                if location
                 else Location(
                     id=location_id,
-                    country_code=country_code if country_code != "" else None,
-                    subdivision_name=subdivision_name if subdivision_name != "" else None,
-                    municipality=municipality if municipality != "" else None,
+                    country_code=country_code,
+                    subdivision_name=subdivision_name,
+                    municipality=municipality,
                 )
             )
-            location = add_entity(location, 1)
-            feed.locations.append(location)
-            add_entity(feed, 1 if isinstance(feed, Gtfsfeed) else 3)
-
-            if feed.data_type == "gtfs_rt":
-                # Entity Type and Entity Type x GTFSRealtimeFeed relationship
-                for entity_type_name in row["entity_type"].replace("|", "-").split("-"):
-                    if len(entity_type_name) == 0:
-                        continue
-                    if entity_type_name not in entity_types_map:
-                        entity_types_map[entity_type_name] = Entitytype(name=entity_type_name)
-                    entity_type = entity_types_map[entity_type_name]
-                    entity_type.feeds.append(feed)
+            feed.locations = [location]
 
-                # External ID
-                if not feed_exists:
-                    mdb_external_id = Externalid(
-                        feed_id=feed.id,
-                        associated_id=str(int(row["mdb_source_id"])),
-                        source="mdb",
-                    )
-                    add_entity(mdb_external_id, 4)
-
-        [add_entity(entity_type, 4) for entity_type in entity_types_map.values()]
-
-        # Iterate again over the contents of the csv files to process the feed references.
+    def process_entity_types(self, feed: Gtfsrealtimefeed, row, stable_id):
+        """
+        Process the entity types for the feed
+        """
+        entity_types = [name for name in self.get_safe_value(row, "entity_type", "").replace("|", "-").split("-") if name]
+        if len(entity_types) > 0:
+            for entity_type_name in entity_types:
+                entity_type = self.db.session.query(Entitytype).filter(Entitytype.name == entity_type_name).first()
+                if not entity_type:
+                    entity_type = Entitytype(name=entity_type_name)
+                if all(entity_type.name != entity.name for entity in feed.entitytypes):
+                    feed.entitytypes.append(entity_type)
+                self.db.session.flush()
+        else:
+            self.logger.warning(f"Entity types array is empty for feed {stable_id}")
+            feed.entitytypes.clear()
+
+    def process_feed_references(self):
+        """
+        Process the feed references
+        """
+        self.logger.info("Processing feed references")
         for index, row in self.df.iterrows():
-            mdb_id = f"mdb-{int(row['mdb_source_id'])}"
-            feed = feed_map[mdb_id]
-            if row["data_type"] == "gtfs_rt":
-                # Feed Reference
-                if row["static_reference"] is not None:
-                    static_reference_mdb_id = f"mdb-{int(row['static_reference'])}"
-                    referenced_feed = feed_map.get(static_reference_mdb_id, None)
-                    already_referenced_ids = {ref.id for ref in referenced_feed.gtfs_rt_feeds}
-                    if referenced_feed and feed.id not in already_referenced_ids:
-                        referenced_feed.gtfs_rt_feeds.append(feed)
-
-            # Process redirects
+            stable_id = self.get_stable_id(row)
+            data_type = self.get_data_type(row)
+            if data_type != "gtfs_rt":
+                continue
+            gtfs_rt_feed = self.query_feed_by_stable_id(stable_id, "gtfs_rt")
+            static_reference = self.get_safe_value(row, "static_reference", "")
+            if static_reference:
+                gtfs_stable_id = f"mdb-{int(float(static_reference))}"
+                gtfs_feed = self.query_feed_by_stable_id(gtfs_stable_id, "gtfs")
+                # Guard against a missing target feed before dereferencing its relationships
+                if gtfs_feed and gtfs_rt_feed.id not in {ref.id for ref in gtfs_feed.gtfs_rt_feeds}:
+                    gtfs_feed.gtfs_rt_feeds.append(gtfs_rt_feed)
+                    # Flush to avoid FK violation
+                    self.db.session.flush()
+
+    def process_redirects(self):
+        """
+        Process the redirects
+        """
+        self.logger.info("Processing redirects")
+        for index, row in self.df.iterrows():
+            stable_id = self.get_stable_id(row)
             raw_redirects = row.get("redirect.id", None)
             redirects_ids = str(raw_redirects).split("|") if raw_redirects is not None else []
+            if len(redirects_ids) == 0:
+                continue
+            feed = self.query_feed_by_stable_id(stable_id, None)
             raw_comments = row.get("redirect.comment", None)
             comments = raw_comments.split("|") if raw_comments is not None else []
-
-            if len(redirects_ids) != len(comments):
-                self.logger.warning(f"Number of redirect ids and redirect comments differ for feed {mdb_id}")
-
+            if len(redirects_ids) != len(comments) and len(comments) > 0:
+                self.logger.warning(f"Number of redirect ids and redirect comments differ for feed {stable_id}")
             for mdb_source_id in redirects_ids:
                 if len(mdb_source_id) == 0:
                     # since there is a 1:1 correspondence between redirect ids and comments, skip also the comment
@@ -223,35 +187,82 @@ def add_entity(entity, priority):
                 else:
                     comment = ""
 
-                target_stable_id = f"mdb-{int(float(mdb_source_id))}"
-                target_feed = feed_map.get(target_stable_id, None)
+                target_stable_id = f"mdb-{int(float(mdb_source_id.strip()))}"
+                target_feed = self.query_feed_by_stable_id(target_stable_id, None)
+                if not target_feed:
+                    self.logger.warning(f"Could not find redirect target feed {target_stable_id} for feed {stable_id}")
+                    continue
 
-                if target_feed:
-                    if target_feed.id != feed.id and (feed.id, target_feed.id) not in redirects_set:
-                        redirect = Redirectingid(source_id=feed.id, target_id=target_feed.id, redirect_comment=comment)
-                        add_entity(redirect, 5)
-                    else:
-                        self.logger.error(f"Feed has redirect pointing to itself {mdb_id}")
+                if feed.id == target_feed.id:
+                    self.logger.error(f"Feed {stable_id} has a redirect pointing to itself")
                 else:
-                    self.logger.warning(f"Could not find redirect target feed {target_stable_id} for feed {mdb_id}")
-
-        priority = 1
-        while not entities_index.empty():
-            next_priority, entity_index = entities_index.get()
-            if priority != next_priority:
-                self.logger.debug(f"Flushing for priority {priority}")
-                priority = next_priority
-                self.db.flush()
-            self.fast_merge(entities[entity_index])
-
-        self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started")
-        self.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
-        self.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
-
-        self.db.commit()
+                    if all(redirect.target_id != target_feed.id for redirect in feed.redirectingids):
+                        feed.redirectingids.append(
+                            Redirectingid(source_id=feed.id, target_id=target_feed.id, redirect_comment=comment)
+                        )
+                        # Flush to avoid FK violation
+                        self.db.session.flush()
+
+    def populate_db(self):
+        """
+        Populate the database with the sources.csv data
+        """
+        self.logger.info("Populating the database with sources.csv data")
+        for index, row in self.df.iterrows():
+            self.logger.debug(f"Populating Database with Feed [mdb_source_id = {row['mdb_source_id']}]")
+            # Create or update the GTFS feed
+            data_type = self.get_data_type(row)
+            stable_id = self.get_stable_id(row)
+            feed = self.query_feed_by_stable_id(stable_id, data_type)
+            if feed:
+                self.logger.debug(f"Updating {feed.__class__.__name__}: {stable_id}")
+            else:
+                feed = self.get_model(data_type)(id=generate_unique_id(), data_type=data_type, stable_id=stable_id)
+                self.logger.info(f"Creating {feed.__class__.__name__}: {stable_id}")
+                self.db.session.add(feed)
+                feed.externalids = [
+                    Externalid(
+                        feed_id=feed.id,
+                        associated_id=str(int(float(row["mdb_source_id"]))),
+                        source="mdb",
+                    )
+                ]
+            # Populate common fields from Feed
+            feed.feed_name = self.get_safe_value(row, "name", "")
+            feed.note = self.get_safe_value(row, "note", "")
+            feed.producer_url = self.get_safe_value(row, "urls.direct_download", "")
+            feed.authentication_type = str(int(float(self.get_safe_value(row, "urls.authentication_type", "0"))))
+            feed.authentication_info_url = self.get_safe_value(row, "urls.authentication_info", "")
+            feed.api_key_parameter_name = self.get_safe_value(row, "urls.api_key_parameter_name", "")
+            feed.license_url = self.get_safe_value(row, "urls.license", "")
+            feed.status = self.get_safe_value(row, "status", "active")
+            feed.feed_contact_email = self.get_safe_value(row, "feed_contact_email", "")
+            feed.provider = self.get_safe_value(row, "provider", "")
+
+            self.populate_location(feed, row, stable_id)
+            if data_type == "gtfs_rt":
+                self.process_entity_types(feed, row, stable_id)
+
+            self.db.session.add(feed)
+            self.db.session.flush()
+        # This needs to be done after all feeds are added to the session to avoid FK violations
+        self.process_feed_references()
+        self.process_redirects()
 
 
 if __name__ == "__main__":
-    filepath = set_up_configs()
-    db_helper = DatabasePopulateHelper(filepath)
-    db_helper.populate()
+    db_helper = DatabasePopulateHelper(set_up_configs())
+    try:
+        configure_polymorphic_mappers()
+        db_helper.populate_db()
+        db_helper.db.session.commit()
+
+        db_helper.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Started")
+        db_helper.db.session.execute(text(f"REFRESH MATERIALIZED VIEW CONCURRENTLY {t_feedsearch.name}"))
+        db_helper.logger.info("Refreshing MATERIALIZED FEED SEARCH VIEW - Completed")
+        db_helper.db.session.commit()
+        db_helper.logger.info("\n----- Database populated with sources.csv data. -----")
+    except Exception as e:
+        db_helper.logger.error(f"\n------ Failed to populate the database with sources.csv: {e} -----\n")
+        db_helper.db.session.rollback()
+        exit(1)