Add teardown task #1529

Merged (14 commits) on Feb 13, 2025
36 changes: 35 additions & 1 deletion cosmos/airflow/graph.py
@@ -12,6 +12,7 @@
from cosmos.config import RenderConfig
from cosmos.constants import (
    DBT_SETUP_ASYNC_TASK_ID,
    DBT_TEARDOWN_ASYNC_TASK_ID,
    DEFAULT_DBT_RESOURCES,
    SUPPORTED_BUILD_RESOURCES,
    TESTABLE_DBT_RESOURCES,
@@ -25,7 +26,7 @@
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger
from cosmos.settings import enable_setup_async_task
from cosmos.settings import enable_setup_async_task, enable_teardown_async_task

logger = get_logger(__name__)

@@ -487,6 +488,37 @@ def calculate_detached_node_name(node: DbtNode) -> str:
    return node_name


def _add_teardown_task(
    dag: DAG,
    execution_mode: ExecutionMode,
    task_args: dict[str, Any],
    tasks_map: dict[str, Any],
    task_group: TaskGroup | None,
    render_config: RenderConfig | None = None,
) -> None:
    if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
        return

    if render_config is not None:
        task_args["select"] = render_config.select
        task_args["selector"] = render_config.selector
        task_args["exclude"] = render_config.exclude

    teardown_task_metadata = TaskMetadata(
        id=DBT_TEARDOWN_ASYNC_TASK_ID,
        operator_class="cosmos.operators._asynchronous.TeardownAsyncOperator",
        arguments=task_args,
        extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)},
    )
    teardown_airflow_task = create_airflow_task(teardown_task_metadata, dag, task_group=task_group)

    for task_id, task in tasks_map.items():
        if len(task.downstream_list) == 0:
            task >> teardown_airflow_task

    tasks_map[DBT_TEARDOWN_ASYNC_TASK_ID] = teardown_airflow_task


def build_airflow_graph(
    nodes: dict[str, DbtNode],
    dag: DAG,  # Airflow-specific - parent DAG where to associate tasks and (optional) task groups
@@ -595,6 +627,8 @@ def build_airflow_graph(
    create_airflow_task_dependencies(nodes, tasks_map)
    if enable_setup_async_task:
        _add_dbt_setup_async_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
    if enable_teardown_async_task:
        _add_teardown_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
    return tasks_map


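To make the fan-in behavior of ``_add_teardown_task`` concrete, here is a minimal sketch using plain Airflow operators (the DAG and task names are hypothetical, not part of this PR): every task that has no downstream dependencies is wired into the single teardown task, so the teardown runs only after all leaves finish.

```python
from datetime import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator

with DAG("teardown_fanin_sketch", start_date=datetime(2025, 1, 1), schedule=None):
    model_a = EmptyOperator(task_id="model_a")
    model_b = EmptyOperator(task_id="model_b")
    model_b_test = EmptyOperator(task_id="model_b_test")
    teardown = EmptyOperator(task_id="dbt_teardown_async")

    model_a >> model_b >> model_b_test

    # Mirrors the loop in _add_teardown_task: only leaf tasks (here, model_b_test)
    # feed into the teardown task.
    for task in [model_a, model_b, model_b_test]:
        if len(task.downstream_list) == 0:
            task >> teardown
```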
1 change: 1 addition & 0 deletions cosmos/constants.py
@@ -162,6 +162,7 @@ def _missing_value_(cls, value): # type: ignore
TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}

DBT_SETUP_ASYNC_TASK_ID = "dbt_setup_async"
DBT_TEARDOWN_ASYNC_TASK_ID = "dbt_teardown_async"

TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}"
TELEMETRY_VERSION = "v1"
12 changes: 12 additions & 0 deletions cosmos/operators/_asynchronous/__init__.py
@@ -17,3 +17,15 @@ def execute(self, context: Context, **kwargs: Any) -> None:
        self.build_and_run_cmd(
            context=context, cmd_flags=self.dbt_cmd_flags, run_as_async=True, async_context=async_context
        )


class TeardownAsyncOperator(DbtRunOperator):
    def __init__(self, *args: Any, **kwargs: Any):
        kwargs["emit_datasets"] = False
        super().__init__(*args, **kwargs)

    def execute(self, context: Context, **kwargs: Any) -> Any:
        async_context = {"profile_type": self.profile_config.get_profile_type(), "teardown_task": True}
        self.build_and_run_cmd(
            context=context, cmd_flags=self.dbt_cmd_flags, run_as_async=True, async_context=async_context
        )
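A usage sketch for the new operator (the project path and profile details below are placeholders, not values from this PR). Since ``TeardownAsyncOperator`` subclasses ``DbtRunOperator``, it accepts the same arguments, and its constructor forces ``emit_datasets=False``:

```python
from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous import TeardownAsyncOperator

teardown = TeardownAsyncOperator(
    task_id="dbt_teardown_async",
    project_dir="/usr/local/airflow/dbt/jaffle_shop",  # placeholder path
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dbt/jaffle_shop/profiles.yml",  # placeholder
    ),
)
```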
23 changes: 22 additions & 1 deletion cosmos/operators/local.py
@@ -36,7 +36,12 @@
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError, CosmosDbtRunError, CosmosValueError
from cosmos.settings import enable_setup_async_task, remote_target_path, remote_target_path_conn_id
from cosmos.settings import (
    enable_setup_async_task,
    enable_teardown_async_task,
    remote_target_path,
    remote_target_path_conn_id,
)

try:
    from airflow.datasets import Dataset
@@ -338,6 +343,18 @@ def _upload_sql_files(self, tmp_project_dir: str, resource_type: str) -> None:
        elapsed_time = time.time() - start_time
        self.log.info("SQL files upload completed in %.2f seconds.", elapsed_time)

    def _delete_sql_files(self, tmp_project_dir: Path, resource_type: str) -> None:
        dest_target_dir, dest_conn_id = self._configure_remote_target_path()
        source_run_dir = Path(tmp_project_dir) / f"target/{resource_type}"
        files = [str(file) for file in source_run_dir.rglob("*") if file.is_file()]
        from airflow.io.path import ObjectStoragePath

        for file_path in files:
            dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_run_dir, resource_type)  # type: ignore
            dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
            dest_object_storage_path.unlink()
            self.log.debug("Deleted remote copy of %s at %s", file_path, dest_object_storage_path)

    @provide_session
    def store_freshness_json(self, tmp_project_dir: str, context: Context, session: Session = NEW_SESSION) -> None:
        """
@@ -466,6 +483,10 @@ def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None
            self.callback(tmp_project_dir, **self.callback_args)

    def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None:
        if async_context.get("teardown_task") and enable_teardown_async_task:
            self._delete_sql_files(Path(tmp_project_dir), "run")
            return

        if enable_setup_async_task:
            self._upload_sql_files(tmp_project_dir, "run")
        else:
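For reference, a standalone sketch of the delete pattern that ``_delete_sql_files`` relies on, using Airflow's Object Storage API (available from Airflow 2.8). The bucket, connection id, and remote layout below are assumptions for illustration; in Cosmos they come from ``remote_target_path``, ``remote_target_path_conn_id``, and ``_construct_dest_file_path``:

```python
from pathlib import Path

from airflow.io.path import ObjectStoragePath


def delete_compiled_sql(tmp_project_dir: str) -> None:
    """Hypothetical helper: remove remote copies of compiled SQL produced by a dbt run."""
    dest_target_dir = "s3://my-bucket/cosmos-target"  # assumed remote_target_path
    dest_conn_id = "aws_default"  # assumed remote_target_path_conn_id

    source_run_dir = Path(tmp_project_dir) / "target" / "run"
    for local_file in source_run_dir.rglob("*"):
        if not local_file.is_file():
            continue
        # Assumed layout: mirror the local path relative to target/run under the remote prefix.
        rel_path = local_file.relative_to(source_run_dir)
        remote_file = ObjectStoragePath(f"{dest_target_dir}/run/{rel_path}", conn_id=dest_conn_id)
        remote_file.unlink()
```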
2 changes: 2 additions & 0 deletions cosmos/settings.py
@@ -39,12 +39,14 @@

# Related to async operators
enable_setup_async_task = conf.getboolean("cosmos", "enable_setup_async_task", fallback=True)
enable_teardown_async_task = conf.getboolean("cosmos", "enable_teardown_async_task", fallback=True)

AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0")

# The following environment variable is populated in Astro Cloud
in_astro_cloud = os.getenv("ASTRONOMER_ENVIRONMENT") == "cloud"


try:
    LINEAGE_NAMESPACE = conf.get("openlineage", "namespace")
except airflow.exceptions.AirflowConfigException:
8 changes: 8 additions & 0 deletions docs/configuration/cosmos-conf.rst
@@ -160,6 +160,14 @@ This page lists all available Airflow configurations that affect ``astronomer-co
- Default: ``True``
- Environment Variable: ``AIRFLOW__COSMOS__ENABLE_SETUP_ASYNC_TASK``

.. _enable_teardown_async_task:

`enable_teardown_async_task`_:
(Introduced in Cosmos 1.9.0): Enables a teardown task for ``ExecutionMode.AIRFLOW_ASYNC`` that deletes the SQL files from the remote location (S3/GCS). You need to specify the ``remote_target_path_conn_id`` and ``remote_target_path`` configurations so the artifacts can be deleted from the remote location.

- Default: ``True``
- Environment Variable: ``AIRFLOW__COSMOS__ENABLE_TEARDOWN_ASYNC_TASK``
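
If you want to keep the uploaded SQL artifacts after the DAG run, the teardown task can be switched off. A minimal sketch of disabling it via the environment variable named above (set it wherever your Airflow environment configures variables, before the scheduler and workers start):

```python
import os

# Equivalent to setting enable_teardown_async_task = False under [cosmos] in airflow.cfg.
os.environ["AIRFLOW__COSMOS__ENABLE_TEARDOWN_ASYNC_TASK"] = "False"
```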

[openlineage]
~~~~~~~~~~~~~

14 changes: 13 additions & 1 deletion tests/operators/_asynchronous/test_base.py
@@ -1,10 +1,11 @@
from __future__ import annotations

from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, Mock, patch

import pytest

from cosmos.config import ProfileConfig
from cosmos.operators._asynchronous import TeardownAsyncOperator
from cosmos.operators._asynchronous.base import DbtRunAirflowAsyncFactoryOperator, _create_async_operator_class
from cosmos.operators._asynchronous.bigquery import DbtRunAirflowAsyncBigqueryOperator
from cosmos.operators.local import DbtRunLocalOperator
@@ -68,3 +69,14 @@ def test_dbt_run_airflow_async_factory_operator_init(mock_create_class, profile_

    assert operator is not None
    assert isinstance(operator, MockAsyncOperator)


@patch("cosmos.operators.local.DbtRunLocalOperator.build_and_run_cmd")
def test_teardown_execute(mock_build_and_run_cmd):
    operator = TeardownAsyncOperator(
        task_id="fake-task",
        profile_config=Mock(),
        project_dir="fake-dir",
    )
    operator.execute({})
    mock_build_and_run_cmd.assert_called_once()
19 changes: 19 additions & 0 deletions tests/operators/test_local.py
@@ -1445,3 +1445,22 @@ def test_async_execution_without_start_task(mock_read_sql, mock_bq_execute, monk
"/tmp", {}, {"profile_type": "bigquery", "async_operator": BigQueryInsertJobOperator}
)
mock_bq_execute.assert_called_once()


@pytest.mark.skipif(not AIRFLOW_IO_AVAILABLE, reason="Airflow did not have Object Storage until the 2.8 release")
@patch("pathlib.Path.rglob")
@patch("cosmos.operators.local.AbstractDbtLocalBase._construct_dest_file_path")
@patch("airflow.io.path.ObjectStoragePath.unlink")
def test_async_execution_teardown_delete_files(mock_unlink, mock_construct_dest_file_path, mock_rglob):
mock_file = MagicMock()
mock_file.is_file.return_value = True
mock_file.__str__.return_value = "/altered_jaffle_shop/target/run/file1.sql"
mock_rglob.return_value = [mock_file]
project_dir = Path(__file__).parent.parent.parent / "dev/dags/dbt/altered_jaffle_shop"
operator = DbtRunLocalOperator(
task_id="test",
project_dir=project_dir,
profile_config=profile_config,
)
operator._handle_async_execution(project_dir, {}, {"profile_type": "bigquery", "teardown_task": True})
mock_unlink.assert_called()
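Putting it together, a hedged end-to-end sketch of a DAG that would get both the setup and teardown tasks added by ``build_airflow_graph`` (the project path and profile details are placeholders; ``ExecutionMode.AIRFLOW_ASYNC`` also needs ``remote_target_path`` and ``remote_target_path_conn_id`` configured, as noted in the docs change above):

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig

jaffle_shop_async = DbtDag(
    dag_id="jaffle_shop_airflow_async",
    start_date=datetime(2025, 1, 1),
    schedule=None,
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),  # placeholder
    profile_config=ProfileConfig(
        profile_name="default",
        target_name="dev",
        profiles_yml_filepath="/usr/local/airflow/dbt/jaffle_shop/profiles.yml",  # placeholder
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
)
# With enable_setup_async_task and enable_teardown_async_task left at their defaults (True),
# the rendered DAG starts with dbt_setup_async and ends with dbt_teardown_async.
```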