Skip to content

Commit

Permalink
Fix: ensure operator execute method is consistent across all execut…
Browse files Browse the repository at this point in the history
…ion base subclasses (#805)

This fixes an issue reported in #804 after the refactor done in
#774 where the
`execute` methods for `DbtLocalBaseOperator`, `DbtDockerBaseOperator`,
and `DbtKubernetesBaseOperator` were different.

This PR refactors the `execute` method to the `AbstractDbtBaseOperator`
so it's the same for all of the local, docker and kubernetes inherited
operators, and adds `build_and_run_cmd` as an abstract method since the
implementation is different across the 3 different execution modes.

Closes #804
  • Loading branch information
jbandoro authored Jan 24, 2024
1 parent ef2c7bb commit 9c090a4
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 10 deletions.
7 changes: 7 additions & 0 deletions cosmos/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,13 @@ def build_cmd(

return dbt_cmd, env

@abstractmethod
def build_and_run_cmd(self, context: Context, cmd_flags: list[str]) -> Any:
"""Override this method for the operator to execute the dbt command"""

def execute(self, context: Context) -> Any | None: # type: ignore
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())


class DbtBuildMixin:
"""Mixin for dbt build command."""
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/docker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def build_command(self, context: Context, cmd_flags: list[str] | None = None) ->
self.environment: dict[str, Any] = {**env_vars, **self.environment}
self.command: list[str] = dbt_cmd

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)


class DbtBuildDockerOperator(DbtBuildMixin, DbtDockerBaseOperator):
"""
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,6 @@ def build_kube_args(self, context: Context, cmd_flags: list[str] | None = None)
self.build_env_args(env_vars)
self.arguments = dbt_cmd

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context)


class DbtBuildKubernetesOperator(DbtBuildMixin, DbtKubernetesBaseOperator):
"""
Expand Down
3 changes: 0 additions & 3 deletions cosmos/operators/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,9 +372,6 @@ def build_and_run_cmd(self, context: Context, cmd_flags: list[str] | None = None
logger.info(result.output)
return result

def execute(self, context: Context) -> None:
self.build_and_run_cmd(context=context, cmd_flags=self.add_cmd_flags())

def on_kill(self) -> None:
if self.cancel_query_on_kill:
self.subprocess_hook.log.info("Sending SIGINT signal to process group")
Expand Down
18 changes: 17 additions & 1 deletion tests/operators/test_base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from unittest.mock import patch

from cosmos.operators.base import (
AbstractDbtBaseOperator,
Expand All @@ -14,11 +15,26 @@

def test_dbt_base_operator_is_abstract():
"""Tests that the abstract base operator cannot be instantiated since the base_cmd is not defined."""
expected_error = "Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods? base_cmd"
expected_error = (
"Can't instantiate abstract class AbstractDbtBaseOperator with abstract methods base_cmd, build_and_run_cmd"
)
with pytest.raises(TypeError, match=expected_error):
AbstractDbtBaseOperator()


@pytest.mark.parametrize("cmd_flags", [["--some-flag"], []])
@patch("cosmos.operators.base.AbstractDbtBaseOperator.build_and_run_cmd")
def test_dbt_base_operator_execute(mock_build_and_run_cmd, cmd_flags, monkeypatch):
"""Tests that the base operator execute method calls the build_and_run_cmd method with the expected arguments."""
monkeypatch.setattr(AbstractDbtBaseOperator, "add_cmd_flags", lambda _: cmd_flags)
AbstractDbtBaseOperator.__abstractmethods__ = set()

base_operator = AbstractDbtBaseOperator(task_id="fake_task", project_dir="fake_dir")

base_operator.execute(context={})
mock_build_and_run_cmd.assert_called_once_with(context={}, cmd_flags=cmd_flags)


@pytest.mark.parametrize(
"dbt_command, dbt_operator_class",
[
Expand Down

0 comments on commit 9c090a4

Please sign in to comment.