Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
romsharon98 committed Nov 25, 2023
1 parent 373d8a5 commit 264eb58
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
12 changes: 9 additions & 3 deletions airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ class SparkKubernetesOperator(BaseOperator):
def __init__(
self,
*,
application_file: str,
application_file: str | dict,
namespace: str | None = None,
kubernetes_conn_id: str = "kubernetes_default",
api_group: str = "sparkoperator.k8s.io",
Expand Down Expand Up @@ -111,7 +111,10 @@ def _get_namespace_event_stream(self, namespace, query_kwargs=None):
raise

def execute(self, context: Context):
body = _load_body_to_dict(self.application_file)
if isinstance(self.application_file, str):
body = _load_body_to_dict(self.application_file)
else:
body = self.application_file
name = body["metadata"]["name"]
namespace = self.namespace or self.hook.get_namespace()

Expand Down Expand Up @@ -177,7 +180,10 @@ def execute(self, context: Context):
return response

def on_kill(self) -> None:
body = _load_body_to_dict(self.application_file)
if isinstance(self.application_file, str):
body = _load_body_to_dict(self.application_file)
else:
body = self.application_file
name = body["metadata"]["name"]
namespace = self.namespace or self.hook.get_namespace()
self.hook.delete_custom_object(
Expand Down
32 changes: 32 additions & 0 deletions tests/providers/cncf/kubernetes/operators/test_spark_kubernetes.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,35 @@ def test_on_kill(mock_kubernetes_hook, mock_load_body_to_dict):
namespace="default",
name="spark-app",
)


@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
def test_execute_with_application_file_dict(mock_kubernetes_hook):
op = SparkKubernetesOperator(task_id="task_id", application_file={"metadata": {"name": "spark-app"}})
mock_kubernetes_hook.return_value.get_namespace.return_value = "default"

op.execute({})

mock_kubernetes_hook.return_value.create_custom_object.assert_called_once_with(
group="sparkoperator.k8s.io",
version="v1beta2",
plural="sparkapplications",
body={"metadata": {"name": "spark-app"}},
namespace="default",
)


@patch("airflow.providers.cncf.kubernetes.operators.spark_kubernetes.KubernetesHook")
def test_on_kill_with_application_file_dict(mock_kubernetes_hook):
op = SparkKubernetesOperator(task_id="task_id", application_file={"metadata": {"name": "spark-app"}})
mock_kubernetes_hook.return_value.get_namespace.return_value = "default"

op.on_kill()

mock_kubernetes_hook.return_value.delete_custom_object.assert_called_once_with(
group="sparkoperator.k8s.io",
version="v1beta2",
plural="sparkapplications",
name="spark-app",
namespace="default",
)

0 comments on commit 264eb58

Please sign in to comment.