Skip to content

Commit

Permalink
Add teardown task
Browse files Browse the repository at this point in the history
  • Loading branch information
pankajastro committed Feb 10, 2025
1 parent d4f2af3 commit d04da59
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 13 deletions.
4 changes: 3 additions & 1 deletion cosmos/airflow/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger
from cosmos.settings import enable_setup_task

logger = get_logger(__name__)

Expand Down Expand Up @@ -592,7 +593,8 @@ def build_airflow_graph(
tasks_map[node_id] = test_task

create_airflow_task_dependencies(nodes, tasks_map)
_add_teardown_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
if enable_setup_task:
_add_teardown_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
return tasks_map


Expand Down
5 changes: 4 additions & 1 deletion cosmos/operators/_asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ def __init__(self, *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)

def execute(self, context: Context, **kwargs: Any) -> Any:
pass
async_context = {"profile_type": self.profile_config.get_profile_type(), "teardown_task": True}
self.build_and_run_cmd(
context=context, cmd_flags=self.dbt_cmd_flags, run_as_async=True, async_context=async_context
)
38 changes: 27 additions & 11 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def _construct_dest_file_path(
dest_target_dir: Path,
file_path: str,
source_compiled_dir: Path,
resource_type: str,
) -> str:
"""
Construct the destination path for the compiled SQL files to be uploaded to the remote store.
Expand All @@ -317,7 +318,7 @@ def _construct_dest_file_path(
dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]
rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")

return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}"
return f"{dest_target_dir_str}/{dag_task_group_identifier}/{resource_type}/{rel_path}"

def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Expand All @@ -338,7 +339,7 @@ def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir)
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir, "compiled")
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
ObjectStoragePath(file_path).copy(dest_object_storage_path)
self.log.debug("Copied %s to %s", file_path, dest_object_storage_path)
Expand Down Expand Up @@ -439,8 +440,8 @@ def _install_dependencies(
def _mock_dbt_adapter(async_context: dict[str, Any] | None) -> None:
if not async_context:
raise CosmosValueError("`async_context` is necessary for running the model asynchronously")
if "async_operator" not in async_context:
raise CosmosValueError("`async_operator` needs to be specified in `async_context` when running as async")
# if "async_operator" not in async_context:
# raise CosmosValueError("`async_operator` needs to be specified in `async_context` when running as async")
if "profile_type" not in async_context:
raise CosmosValueError("`profile_type` needs to be specified in `async_context` when running as async")
profile_type = async_context["profile_type"]
Expand Down Expand Up @@ -471,14 +472,29 @@ def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None
self.callback_args.update({"context": context})
self.callback(tmp_project_dir, **self.callback_args)

def _delete_sql_files(self, tmp_project_dir: Path, resource_type: str) -> None:
dest_target_dir, dest_conn_id = self._configure_remote_target_path()
source_run_dir = Path(tmp_project_dir) / f"target/{resource_type}"
files = [str(file) for file in source_run_dir.rglob("*") if file.is_file()]
from airflow.io.path import ObjectStoragePath

for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_run_dir, resource_type) # type: ignore
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
dest_object_storage_path.unlink()
self.log.info("Copied %s to %s", file_path, dest_object_storage_path)

def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
profile_type = async_context["profile_type"]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_configure_{profile_type}_async_op_args"
async_op_configurator = load_method_from_module(module_path, method_name)
async_op_configurator(self, sql=sql)
async_context["async_operator"].execute(self, context)
if async_context.get("teardown_task"):
self._delete_sql_files(Path(tmp_project_dir), "run")
else:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
profile_type = async_context["profile_type"]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_configure_{profile_type}_async_op_args"
async_op_configurator = load_method_from_module(module_path, method_name)
async_op_configurator(self, sql=sql)
async_context["async_operator"].execute(self, context)

def run_command(
self,
Expand Down
3 changes: 3 additions & 0 deletions cosmos/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
# The following environment variable is populated in Astro Cloud
in_astro_cloud = os.getenv("ASTRONOMER_ENVIRONMENT") == "cloud"

# Related to async operators
enable_setup_task = conf.getboolean("cosmos", "enable_setup_task", fallback=True)

try:
LINEAGE_NAMESPACE = conf.get("openlineage", "namespace")
except airflow.exceptions.AirflowConfigException:
Expand Down

0 comments on commit d04da59

Please sign in to comment.