From 31895f412eba60e274c7158a39cf0a8455ffd671 Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Mon, 3 Oct 2022 00:07:52 -0700 Subject: [PATCH] chore(tags): Refactor logic to leverage Flask-SQLAlchemy extension (#21459) --- superset/cli/update.py | 7 ++-- superset/common/tags.py | 76 +++++++++++++++++++++++------------------ 2 files changed, 45 insertions(+), 38 deletions(-) diff --git a/superset/cli/update.py b/superset/cli/update.py index e2054485c6367..bdc54db3a9c99 100755 --- a/superset/cli/update.py +++ b/superset/cli/update.py @@ -30,7 +30,6 @@ from flask_appbuilder.api.manager import resolver import superset.utils.database as database_utils -from superset.extensions import db from superset.utils.encrypt import SecretsMigrator logger = logging.getLogger(__name__) @@ -62,9 +61,9 @@ def sync_tags() -> None: # pylint: disable=import-outside-toplevel from superset.common.tags import add_favorites, add_owners, add_types - add_types(db.engine, metadata) - add_owners(db.engine, metadata) - add_favorites(db.engine, metadata) + add_types(metadata) + add_owners(metadata) + add_favorites(metadata) @click.command() diff --git a/superset/common/tags.py b/superset/common/tags.py index d85a33b84e140..706192913a1c3 100644 --- a/superset/common/tags.py +++ b/superset/common/tags.py @@ -17,15 +17,15 @@ from typing import Any, List from sqlalchemy import MetaData -from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import and_, func, join, literal, select +from superset.extensions import db from superset.tags.models import ObjectTypes, TagTypes def add_types_to_charts( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: slices = metadata.tables["slices"] @@ -53,11 +53,11 @@ def add_types_to_charts( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, charts) - engine.execute(query) + db.session.execute(query) def add_types_to_dashboards( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: dashboard_table = metadata.tables["dashboards"] @@ -85,11 +85,11 @@ def add_types_to_dashboards( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, dashboards) - engine.execute(query) + db.session.execute(query) def add_types_to_saved_queries( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: saved_query = metadata.tables["saved_query"] @@ -117,11 +117,11 @@ def add_types_to_saved_queries( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, saved_queries) - engine.execute(query) + db.session.execute(query) def add_types_to_datasets( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: tables = metadata.tables["tables"] @@ -149,10 +149,10 @@ def add_types_to_datasets( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, datasets) - engine.execute(query) + db.session.execute(query) -def add_types(engine: Engine, metadata: MetaData) -> None: +def add_types(metadata: MetaData) -> None: """ Tag every object according to its type: @@ -222,18 +222,22 @@ def add_types(engine: Engine, metadata: MetaData) -> None: insert = tag.insert() for type_ in ObjectTypes.__members__: try: - engine.execute(insert, name=f"type:{type_}", type=TagTypes.type) + db.session.execute( + insert, + name=f"type:{type_}", + type=TagTypes.type, + ) except IntegrityError: pass # already exists - add_types_to_charts(engine, metadata, tag, tagged_object, columns) - add_types_to_dashboards(engine, metadata, tag, tagged_object, columns) - add_types_to_saved_queries(engine, metadata, tag, tagged_object, columns) - add_types_to_datasets(engine, metadata, tag, tagged_object, columns) + add_types_to_charts(metadata, tag, tagged_object, columns) + add_types_to_dashboards(metadata, tag, tagged_object, columns) + add_types_to_saved_queries(metadata, tag, tagged_object, columns) + add_types_to_datasets(metadata, tag, tagged_object, columns) def add_owners_to_charts( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: slices = metadata.tables["slices"] @@ -265,11 +269,11 @@ def add_owners_to_charts( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, charts) - engine.execute(query) + db.session.execute(query) def add_owners_to_dashboards( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: dashboard_table = metadata.tables["dashboards"] @@ -301,11 +305,11 @@ def add_owners_to_dashboards( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, dashboards) - engine.execute(query) + db.session.execute(query) def add_owners_to_saved_queries( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: saved_query = metadata.tables["saved_query"] @@ -337,11 +341,11 @@ def add_owners_to_saved_queries( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, saved_queries) - engine.execute(query) + db.session.execute(query) def add_owners_to_datasets( - engine: Engine, metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] + metadata: MetaData, tag: Any, tagged_object: Any, columns: List[str] ) -> None: tables = metadata.tables["tables"] @@ -373,10 +377,10 @@ def add_owners_to_datasets( .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, datasets) - engine.execute(query) + db.session.execute(query) -def add_owners(engine: Engine, metadata: MetaData) -> None: +def add_owners(metadata: MetaData) -> None: """ Tag every object according to its owner: @@ -443,19 +447,19 @@ def add_owners(engine: Engine, metadata: MetaData) -> None: # create a custom tag for each user ids = select([users.c.id]) insert = tag.insert() - for (id_,) in engine.execute(ids): + for (id_,) in db.session.execute(ids): try: - engine.execute(insert, name=f"owner:{id_}", type=TagTypes.owner) + db.session.execute(insert, name=f"owner:{id_}", type=TagTypes.owner) except IntegrityError: pass # already exists - add_owners_to_charts(engine, metadata, tag, tagged_object, columns) - add_owners_to_dashboards(engine, metadata, tag, tagged_object, columns) - add_owners_to_saved_queries(engine, metadata, tag, tagged_object, columns) - add_owners_to_datasets(engine, metadata, tag, tagged_object, columns) + add_owners_to_charts(metadata, tag, tagged_object, columns) + add_owners_to_dashboards(metadata, tag, tagged_object, columns) + add_owners_to_saved_queries(metadata, tag, tagged_object, columns) + add_owners_to_datasets(metadata, tag, tagged_object, columns) -def add_favorites(engine: Engine, metadata: MetaData) -> None: +def add_favorites(metadata: MetaData) -> None: """ Tag every object that was favorited: @@ -484,9 +488,13 @@ def add_favorites(engine: Engine, metadata: MetaData) -> None: # create a custom tag for each user ids = select([users.c.id]) insert = tag.insert() - for (id_,) in engine.execute(ids): + for (id_,) in db.session.execute(ids): try: - engine.execute(insert, name=f"favorited_by:{id_}", type=TagTypes.type) + db.session.execute( + insert, + name=f"favorited_by:{id_}", + type=TagTypes.type, + ) except IntegrityError: pass # already exists @@ -518,4 +526,4 @@ def add_favorites(engine: Engine, metadata: MetaData) -> None: .where(tagged_object.c.tag_id.is_(None)) ) query = tagged_object.insert().from_select(columns, favstars) - engine.execute(query) + db.session.execute(query)