Add setup task for async executions (#1518)

This PR reintroduces the setup task for `ExecutionMode.AIRFLOW_ASYNC`. With the setup task in place, the dbt command is no longer executed for every Airflow task node handled by the run operator; instead, the mocked version of the dbt command runs only once, in the setup task at the start of the DAG.

Additionally, a new configuration, `enable_setup_async_task`, has been introduced to enable or disable this feature.
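
For context, a minimal sketch of a `DbtDag` that would exercise the new setup task. This is a sketch only: the project path, connection ID, dataset, and profile mapping below are placeholder assumptions, not taken from this PR.

```python
from datetime import datetime

from cosmos import DbtDag, ExecutionConfig, ExecutionMode, ProfileConfig, ProjectConfig
from cosmos.profiles import GoogleCloudServiceAccountDictProfileMapping

# The setup task is enabled by default (enable_setup_async_task=True).
# remote_target_path / remote_target_path_conn_id must also be configured so
# the setup task has somewhere to upload the generated SQL files.
async_dag = DbtDag(
    dag_id="jaffle_shop_async",
    project_config=ProjectConfig("/usr/local/airflow/dbt/jaffle_shop"),
    profile_config=ProfileConfig(
        profile_name="jaffle_shop",
        target_name="dev",
        profile_mapping=GoogleCloudServiceAccountDictProfileMapping(
            conn_id="gcp_conn",
            profile_args={"dataset": "my_dataset"},
        ),
    ),
    execution_config=ExecutionConfig(execution_mode=ExecutionMode.AIRFLOW_ASYNC),
    # Depending on the Cosmos version, async mode may require full_refresh and
    # the BigQuery location via operator_args (see the execution-modes docs in this diff).
    operator_args={"full_refresh": True, "location": "US"},
    schedule=None,
    start_date=datetime(2025, 1, 1),
    catchup=False,
)
```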

**With `enable_setup_async_task`**

<img width="1686" alt="Screenshot 2025-02-08 at 9 53 35 PM" src="https://github.com/user-attachments/assets/cb9daf76-2bd6-4bcf-8219-b64562c4151a" />


**Without `enable_setup_async_task`**

<img width="1668" alt="Screenshot 2025-02-08 at 10 23 49 PM" src="https://github.com/user-attachments/assets/4ac99c3b-15df-484e-b9e4-cc7b1022ea31" />



Related-to: #1477
pankajastro authored Feb 11, 2025
1 parent 0fd61c3 commit 577c08e
Showing 12 changed files with 168 additions and 71 deletions.
2 changes: 1 addition & 1 deletion cosmos/__init__.py
@@ -6,7 +6,7 @@
Contains dags, task groups, and operators.
"""

__version__ = "1.9.0a5"
__version__ = "1.9.0a6"


from cosmos.airflow.dag import DbtDag
36 changes: 35 additions & 1 deletion cosmos/airflow/graph.py
@@ -11,6 +11,7 @@

from cosmos.config import RenderConfig
from cosmos.constants import (
DBT_SETUP_ASYNC_TASK_ID,
DEFAULT_DBT_RESOURCES,
SUPPORTED_BUILD_RESOURCES,
TESTABLE_DBT_RESOURCES,
@@ -24,6 +25,7 @@
from cosmos.core.graph.entities import Task as TaskMetadata
from cosmos.dbt.graph import DbtNode
from cosmos.log import get_logger
from cosmos.settings import enable_setup_async_task

logger = get_logger(__name__)

@@ -138,7 +140,6 @@ def create_test_task_metadata(
task_args["on_warning_callback"] = on_warning_callback
extra_context = {}
detached_from_parent = detached_from_parent or {}

task_owner = ""

if test_indirect_selection != TestIndirectSelection.EAGER:
@@ -404,6 +405,37 @@ def _get_dbt_dag_task_group_identifier(dag: DAG, task_group: TaskGroup | None) -
return dag_task_group_identifier


def _add_dbt_setup_async_task(
dag: DAG,
execution_mode: ExecutionMode,
task_args: dict[str, Any],
tasks_map: dict[str, Any],
task_group: TaskGroup | None,
render_config: RenderConfig | None = None,
) -> None:
if execution_mode != ExecutionMode.AIRFLOW_ASYNC:
return

if render_config is not None:
task_args["select"] = render_config.select
task_args["selector"] = render_config.selector
task_args["exclude"] = render_config.exclude

setup_task_metadata = TaskMetadata(
id=DBT_SETUP_ASYNC_TASK_ID,
operator_class="cosmos.operators._asynchronous.SetupAsyncOperator",
arguments=task_args,
extra_context={"dbt_dag_task_group_identifier": _get_dbt_dag_task_group_identifier(dag, task_group)},
)
setup_airflow_task = create_airflow_task(setup_task_metadata, dag, task_group=task_group)

for task_id, task in tasks_map.items():
if not task.upstream_list:
setup_airflow_task >> task

tasks_map[DBT_SETUP_ASYNC_TASK_ID] = setup_airflow_task


def should_create_detached_nodes(render_config: RenderConfig) -> bool:
"""
Decide if we should calculate / insert detached nodes into the graph.
@@ -561,6 +593,8 @@ def build_airflow_graph(
tasks_map[node_id] = test_task

create_airflow_task_dependencies(nodes, tasks_map)
if enable_setup_async_task:
_add_dbt_setup_async_task(dag, execution_mode, task_args, tasks_map, task_group, render_config=render_config)
return tasks_map


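For orientation, a rough sketch of what `_add_dbt_setup_async_task` does to the task graph; the model/task names below are hypothetical, not taken from the commit.

```python
# Before: the root dbt tasks (those with no upstream) start the DAG directly.
# After:  the setup task becomes the single root and fans out to them:
#
#   dbt_setup_async >> stg_customers_run
#   dbt_setup_async >> stg_orders_run
#
# The new task is registered in tasks_map under DBT_SETUP_ASYNC_TASK_ID
# ("dbt_setup_async"), taking over the role of the former "dbt_compile" root task.
```
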
2 changes: 1 addition & 1 deletion cosmos/constants.py
@@ -161,7 +161,7 @@ def _missing_value_(cls, value): # type: ignore
# https://docs.getdbt.com/reference/commands/test
TESTABLE_DBT_RESOURCES = {DbtResourceType.MODEL, DbtResourceType.SOURCE, DbtResourceType.SNAPSHOT, DbtResourceType.SEED}

DBT_COMPILE_TASK_ID = "dbt_compile"
DBT_SETUP_ASYNC_TASK_ID = "dbt_setup_async"

TELEMETRY_URL = "https://astronomer.gateway.scarf.sh/astronomer-cosmos/{telemetry_version}/{cosmos_version}/{airflow_version}/{python_version}/{platform_system}/{platform_machine}/{event_type}/{status}/{dag_hash}/{task_count}/{cosmos_task_count}"
TELEMETRY_VERSION = "v1"
19 changes: 19 additions & 0 deletions cosmos/operators/_asynchronous/__init__.py
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import Any

from airflow.utils.context import Context

from cosmos.operators.local import DbtRunLocalOperator as DbtRunOperator


class SetupAsyncOperator(DbtRunOperator):
def __init__(self, *args: Any, **kwargs: Any):
kwargs["emit_datasets"] = False
super().__init__(*args, **kwargs)

def execute(self, context: Context, **kwargs: Any) -> None:
async_context = {"profile_type": self.profile_config.get_profile_type()}
self.build_and_run_cmd(
context=context, cmd_flags=self.dbt_cmd_flags, run_as_async=True, async_context=async_context
)
36 changes: 34 additions & 2 deletions cosmos/operators/_asynchronous/bigquery.py
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import Any, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, Sequence

import airflow
from airflow.providers.google.cloud.operators.bigquery import BigQueryInsertJobOperator
@@ -12,6 +13,7 @@
from cosmos.dataset import get_dataset_alias_name
from cosmos.exceptions import CosmosValueError
from cosmos.operators.local import AbstractDbtLocalBase
from cosmos.settings import enable_setup_async_task, remote_target_path, remote_target_path_conn_id

AIRFLOW_VERSION = Version(airflow.__version__)

@@ -97,5 +99,35 @@ def __init__(
def base_cmd(self) -> list[str]:
return ["run"]

def get_remote_sql(self) -> str:
if not settings.AIRFLOW_IO_AVAILABLE: # pragma: no cover
raise CosmosValueError(f"Cosmos async support is only available starting in Airflow 2.8 or later.")
from airflow.io.path import ObjectStoragePath

file_path = self.async_context["dbt_node_config"]["file_path"] # type: ignore
dbt_dag_task_group_identifier = self.async_context["dbt_dag_task_group_identifier"]

remote_target_path_str = str(remote_target_path).rstrip("/")

if TYPE_CHECKING: # pragma: no cover
assert self.project_dir is not None

project_dir_parent = str(Path(self.project_dir).parent)
relative_file_path = str(file_path).replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path_str}/{dbt_dag_task_group_identifier}/run/{relative_file_path}"

object_storage_path = ObjectStoragePath(remote_model_path, conn_id=remote_target_path_conn_id)
with object_storage_path.open() as fp: # type: ignore
return fp.read() # type: ignore

def execute(self, context: Context, **kwargs: Any) -> None:
self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context)
if enable_setup_async_task:
self.configuration = {
"query": {
"query": self.get_remote_sql(),
"useLegacySql": False,
}
}
super().execute(context=context)
else:
self.build_and_run_cmd(context=context, run_as_async=True, async_context=self.async_context)
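
To make the lookup in `get_remote_sql` above concrete, here is a small sketch of how the remote path is assembled; the bucket, DAG identifier, and model path are made-up placeholders.

```python
remote_target_path = "gs://my-bucket/cosmos-target"  # AIRFLOW__COSMOS__REMOTE_TARGET_PATH
dbt_dag_task_group_identifier = "my_async_dag"       # from the operator's async context
file_path = "/usr/local/airflow/dbt/jaffle_shop/models/staging/stg_orders.sql"
project_dir_parent = "/usr/local/airflow/dbt"        # parent of the dbt project dir

relative_file_path = file_path.replace(project_dir_parent, "").lstrip("/")
remote_model_path = f"{remote_target_path.rstrip('/')}/{dbt_dag_task_group_identifier}/run/{relative_file_path}"
print(remote_model_path)
# gs://my-bucket/cosmos-target/my_async_dag/run/jaffle_shop/models/staging/stg_orders.sql
```
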
49 changes: 20 additions & 29 deletions cosmos/operators/local.py
@@ -35,7 +35,7 @@
from cosmos.dataset import get_dataset_alias_name
from cosmos.dbt.project import get_partial_parse_path, has_non_empty_dependencies_file
from cosmos.exceptions import AirflowCompatibilityError, CosmosDbtRunError, CosmosValueError
from cosmos.settings import remote_target_path, remote_target_path_conn_id
from cosmos.settings import enable_setup_async_task, remote_target_path, remote_target_path_conn_id

try:
from airflow.datasets import Dataset
@@ -305,10 +305,7 @@ def _configure_remote_target_path() -> tuple[Path, str] | tuple[None, None]:
return _configured_target_path, remote_conn_id

def _construct_dest_file_path(
self,
dest_target_dir: Path,
file_path: str,
source_compiled_dir: Path,
self, dest_target_dir: Path, file_path: str, source_compiled_dir: Path, resource_type: str
) -> str:
"""
Construct the destination path for the compiled SQL files to be uploaded to the remote store.
@@ -317,28 +314,20 @@ def _construct_dest_file_path(
dag_task_group_identifier = self.extra_context["dbt_dag_task_group_identifier"]
rel_path = os.path.relpath(file_path, source_compiled_dir).lstrip("/")

return f"{dest_target_dir_str}/{dag_task_group_identifier}/compiled/{rel_path}"

def upload_compiled_sql(self, tmp_project_dir: str, context: Context) -> None:
"""
Uploads the compiled SQL files from the dbt compile output to the remote store.
"""
if not self.should_upload_compiled_sql:
return
return f"{dest_target_dir_str}/{dag_task_group_identifier}/{resource_type}/{rel_path}"

def _upload_sql_files(self, tmp_project_dir: str, resource_type: str) -> None:
dest_target_dir, dest_conn_id = self._configure_remote_target_path()

if not dest_target_dir:
raise CosmosValueError(
"You're trying to upload compiled SQL files, but the remote target path is not configured. "
)
raise CosmosValueError("You're trying to upload SQL files, but the remote target path is not configured. ")

from airflow.io.path import ObjectStoragePath

source_compiled_dir = Path(tmp_project_dir) / "target" / "compiled"
files = [str(file) for file in source_compiled_dir.rglob("*") if file.is_file()]
source_run_dir = Path(tmp_project_dir) / f"target/{resource_type}"
files = [str(file) for file in source_run_dir.rglob("*") if file.is_file()]
for file_path in files:
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_compiled_dir)
dest_file_path = self._construct_dest_file_path(dest_target_dir, file_path, source_run_dir, resource_type)
dest_object_storage_path = ObjectStoragePath(dest_file_path, conn_id=dest_conn_id)
ObjectStoragePath(file_path).copy(dest_object_storage_path)
self.log.debug("Copied %s to %s", file_path, dest_object_storage_path)
@@ -439,8 +428,6 @@ def _install_dependencies(
def _mock_dbt_adapter(async_context: dict[str, Any] | None) -> None:
if not async_context:
raise CosmosValueError("`async_context` is necessary for running the model asynchronously")
if "async_operator" not in async_context:
raise CosmosValueError("`async_operator` needs to be specified in `async_context` when running as async")
if "profile_type" not in async_context:
raise CosmosValueError("`profile_type` needs to be specified in `async_context` when running as async")
profile_type = async_context["profile_type"]
@@ -466,19 +453,23 @@ def _update_partial_parse_cache(self, tmp_dir_path: Path) -> None:
def _handle_post_execution(self, tmp_project_dir: str, context: Context) -> None:
self.store_freshness_json(tmp_project_dir, context)
self.store_compiled_sql(tmp_project_dir, context)
self.upload_compiled_sql(tmp_project_dir, context)
if self.should_upload_compiled_sql:
self._upload_sql_files(tmp_project_dir, "compiled")
if self.callback:
self.callback_args.update({"context": context})
self.callback(tmp_project_dir, **self.callback_args)

def _handle_async_execution(self, tmp_project_dir: str, context: Context, async_context: dict[str, Any]) -> None:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
profile_type = async_context["profile_type"]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_configure_{profile_type}_async_op_args"
async_op_configurator = load_method_from_module(module_path, method_name)
async_op_configurator(self, sql=sql)
async_context["async_operator"].execute(self, context)
if enable_setup_async_task:
self._upload_sql_files(tmp_project_dir, "run")
else:
sql = self._read_run_sql_from_target_dir(tmp_project_dir, async_context)
profile_type = async_context["profile_type"]
module_path = f"cosmos.operators._asynchronous.{profile_type}"
method_name = f"_configure_{profile_type}_async_op_args"
async_op_configurator = load_method_from_module(module_path, method_name)
async_op_configurator(self, sql=sql)
async_context["async_operator"].execute(self, context)

def run_command(
self,
3 changes: 3 additions & 0 deletions cosmos/settings.py
@@ -37,6 +37,9 @@
remote_target_path = conf.get("cosmos", "remote_target_path", fallback=None)
remote_target_path_conn_id = conf.get("cosmos", "remote_target_path_conn_id", fallback=None)

# Related to async operators
enable_setup_async_task = conf.getboolean("cosmos", "enable_setup_async_task", fallback=True)

AIRFLOW_IO_AVAILABLE = Version(airflow_version) >= Version("2.8.0")

# The following environment variable is populated in Astro Cloud
8 changes: 8 additions & 0 deletions docs/configuration/cosmos-conf.rst
@@ -152,6 +152,14 @@ This page lists all available Airflow configurations that affect ``astronomer-co
- Default: ``None``
- Environment Variable: ``AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID``

.. _enable_setup_async_task:

`enable_setup_async_task`_:
    (Introduced in Cosmos 1.9.0): Enables a setup task for ``ExecutionMode.AIRFLOW_ASYNC`` that generates the SQL files and uploads them to a remote location (S3/GCS), preventing the ``run`` command from being executed on every node. You need to specify the ``remote_target_path_conn_id`` and ``remote_target_path`` configurations to upload the artifacts to the remote location.

- Default: ``True``
- Environment Variable: ``AIRFLOW__COSMOS__ENABLE_SETUP_ASYNC_TASK``

[openlineage]
~~~~~~~~~~~~~

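As an illustration of the configuration entry above, the new setting and its companion remote-path settings could be supplied via environment variables; the bucket name and connection ID below are placeholders.

```python
import os

# Toggle the async setup task (defaults to True) and point Cosmos at the
# remote location where the generated run SQL files should be uploaded.
os.environ["AIRFLOW__COSMOS__ENABLE_SETUP_ASYNC_TASK"] = "True"
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH"] = "gs://my-bucket/cosmos-target"
os.environ["AIRFLOW__COSMOS__REMOTE_TARGET_PATH_CONN_ID"] = "gcp_conn"
```
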
8 changes: 3 additions & 5 deletions docs/getting_started/execution-modes.rst
@@ -295,11 +295,11 @@ This execution mode could be preferred when you've long running resources and yo
leveraging Airflow's deferrable operators. With that, you would be able to potentially observe higher throughput of tasks
as more dbt nodes will be run in parallel since they won't be blocking Airflow's worker slots.

In this mode, Cosmos adds a new operator, ``DbtCompileAirflowAsyncOperator``, as a root task in the DbtDag or DbtTaskGroup. The task runs
the ``dbt compile`` command on your dbt project which then outputs compiled SQLs in the project's target directory.
In this mode, Cosmos adds a new operator, ``SetupAsyncOperator``, as a root task in the DbtDag or DbtTaskGroup. The task runs
the mocked ``dbt run`` command on your dbt project which then outputs compiled SQLs in the project's target directory.
As part of the same task run, these compiled SQLs are then stored remotely to a remote path set using the
:ref:`remote_target_path` configuration. The remote path is then used by the subsequent tasks in the DAG to
fetch (from the remote path) and run the compiled SQLs asynchronously using e.g. the ``DbtRunAirflowAsyncOperator``.
fetch (from the remote path) and run the compiled SQLs asynchronously using e.g. the ``DbtRunAirflowAsyncOperator``.
You may observe that the compile task takes a bit longer to run due to the latency of storing the compiled SQLs
remotely (e.g. for the classic ``jaffle_shop`` dbt project, upon compiling it produces about 31 files measuring about 124KB in total, but on a local
machine it took approximately 25 seconds for the task to compile & upload the compiled SQLs to the remote path).
@@ -312,9 +312,7 @@ Note that currently, the ``airflow_async`` execution mode has the following limi
2. **Limited to dbt models**: Only dbt resource type models are run asynchronously using Airflow deferrable operators. Other resource types are executed synchronously, similar to the local execution mode.
3. **BigQuery support only**: This mode only supports BigQuery as the target database. If a different target is specified, Cosmos will throw an error indicating the target database is unsupported in this mode.
4. **ProfileMapping parameter required**: You need to specify the ``ProfileMapping`` parameter in the ``ProfileConfig`` for your DAG. Refer to the example DAG below for details on setting this parameter.
5. **Supports only full_refresh models**: Currently, only ``full_refresh`` models are supported. To enable this, pass ``full_refresh=True`` in the ``operator_args`` of the ``DbtDag`` or ``DbtTaskGroup``. Refer to the example DAG below for details on setting this parameter.
6. **location parameter required**: You must specify the location of the BigQuery dataset in the ``operator_args`` of the ``DbtDag`` or ``DbtTaskGroup``. The example DAG below provides guidance on this.
7. **No dataset emission**: The async run operators do not currently emit datasets, meaning that :ref:`data-aware-scheduling` is not supported at this time. Future releases will address this limitation.

To start leveraging the async execution mode, which is currently supported for BigQuery profile type targets, you need to install Cosmos with the additional dependencies below:

3 changes: 2 additions & 1 deletion tests/operators/_asynchronous/test_bigquery.py
@@ -54,8 +54,9 @@ def test_dbt_run_airflow_async_bigquery_operator_base_cmd(profile_config_mock):


@patch.object(DbtRunAirflowAsyncBigqueryOperator, "build_and_run_cmd")
def test_dbt_run_airflow_async_bigquery_operator_execute(mock_build_and_run_cmd, profile_config_mock):
def test_dbt_run_airflow_async_bigquery_operator_execute(mock_build_and_run_cmd, profile_config_mock, monkeypatch):
"""Test execute calls build_and_run_cmd with correct parameters."""
monkeypatch.setattr("cosmos.operators._asynchronous.bigquery.enable_setup_async_task", False)
operator = DbtRunAirflowAsyncBigqueryOperator(
task_id="test_task",
project_dir="/path/to/project",
