Fix DataflowJobLink for Beam operators in deferrable mode #45023

Merged
34 changes: 29 additions & 5 deletions providers/src/airflow/providers/apache/beam/hooks/beam.py
@@ -451,6 +451,7 @@ async def start_python_pipeline_async(
py_interpreter: str = "python3",
py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
process_line_callback: Callable[[str], None] | None = None,
):
"""
Start Apache Beam python pipeline.
@@ -470,6 +471,8 @@ async def start_python_pipeline_async(
:param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
See virtualenv documentation for more information.
This option is only relevant if the ``py_requirements`` parameter is not None.
:param process_line_callback: Optional callback which can be used to process
stdout and stderr to detect job id
"""
py_options = py_options or []
if "labels" in variables:
@@ -518,16 +521,25 @@ async def start_python_pipeline_async(
return_code = await self.start_pipeline_async(
variables=variables,
command_prefix=command_prefix,
process_line_callback=process_line_callback,
)
return return_code

async def start_java_pipeline_async(self, variables: dict, jar: str, job_class: str | None = None):
async def start_java_pipeline_async(
self,
variables: dict,
jar: str,
job_class: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
):
"""
Start Apache Beam Java pipeline.

:param variables: Variables passed to the job.
:param jar: Name of the jar for the pipeline.
:param job_class: Name of the java class for the pipeline.
:param process_line_callback: Optional callback which can be used to process
stdout and stderr to detect job id
:return: Beam command execution return code.
"""
if "labels" in variables:
@@ -537,6 +549,7 @@ async def start_java_pipeline_async(self, variables: dict, jar: str, job_class:
return_code = await self.start_pipeline_async(
variables=variables,
command_prefix=command_prefix,
process_line_callback=process_line_callback,
)
return return_code

@@ -545,6 +558,7 @@ async def start_pipeline_async(
variables: dict,
command_prefix: list[str],
working_directory: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
) -> int:
cmd = [*command_prefix, f"--runner={self.runner}"]
if variables:
@@ -553,20 +567,24 @@ async def start_pipeline_async(
cmd=cmd,
working_directory=working_directory,
log=self.log,
process_line_callback=process_line_callback,
)

async def run_beam_command_async(
self,
cmd: list[str],
log: logging.Logger,
working_directory: str | None = None,
process_line_callback: Callable[[str], None] | None = None,
) -> int:
"""
Run pipeline command in subprocess.

:param cmd: Parts of the command to be run in subprocess
:param working_directory: Working directory
:param log: logger.
:param log: logger
:param process_line_callback: Optional callback which can be used to process
stdout and stderr to detect job id
"""
cmd_str_representation = " ".join(shlex.quote(c) for c in cmd)
log.info("Running command: %s", cmd_str_representation)
@@ -584,8 +602,8 @@ async def run_beam_command_async(
log.info("Start waiting for Apache Beam process to complete.")

# Creating separate threads for stdout and stderr
stdout_task = asyncio.create_task(self.read_logs(process.stdout))
stderr_task = asyncio.create_task(self.read_logs(process.stderr))
stdout_task = asyncio.create_task(self.read_logs(process.stdout, process_line_callback))
stderr_task = asyncio.create_task(self.read_logs(process.stderr, process_line_callback))

# Waiting for the both tasks to complete
await asyncio.gather(stdout_task, stderr_task)
@@ -598,10 +616,16 @@ async def run_beam_command_async(
raise AirflowException(f"Apache Beam process failed with return code {return_code}")
return return_code

async def read_logs(self, stream_reader):
async def read_logs(
self,
stream_reader,
process_line_callback: Callable[[str], None] | None = None,
):
while True:
line = await stream_reader.readline()
if not line:
break
decoded_line = line.decode().strip()
if process_line_callback:
process_line_callback(decoded_line)
self.log.info(decoded_line)
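
For readers of the hook change above: `run_beam_command_async` now threads an optional `process_line_callback` into `read_logs`, so every decoded stdout/stderr line can be inspected (for example, to extract a Dataflow job ID) before it is logged. A minimal, self-contained sketch of that plumbing, using a fake stream reader and a callback that just collects lines (everything below is illustrative, not the provider's code):

```python
import asyncio

# Stand-in for asyncio.StreamReader that yields canned Beam process output.
class FakeStreamReader:
    def __init__(self, lines: list[bytes]) -> None:
        self._lines = list(lines)

    async def readline(self) -> bytes:
        return self._lines.pop(0) if self._lines else b""

# Mirrors the read_logs loop above: decode each line, hand it to the
# optional callback, then log/print it.
async def read_logs(stream_reader, process_line_callback=None) -> None:
    while True:
        line = await stream_reader.readline()
        if not line:
            break
        decoded_line = line.decode().strip()
        if process_line_callback:
            process_line_callback(decoded_line)
        print(decoded_line)

seen: list[str] = []
asyncio.run(read_logs(
    FakeStreamReader([b"Dataflow SDK version: 2.x\n", b"Submitted job: sample-job-id\n"]),
    process_line_callback=seen.append,
))
print(seen)  # ['Dataflow SDK version: 2.x', 'Submitted job: sample-job-id']
```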
27 changes: 13 additions & 14 deletions providers/src/airflow/providers/apache/beam/operators/beam.py
@@ -274,6 +274,17 @@ def execute_complete(self, context: Context, event: dict[str, Any]):
self.task_id,
event["message"],
)
self.dataflow_job_id = event["dataflow_job_id"]
self.project_id = event["project_id"]
self.location = event["location"]

DataflowJobLink.persist(
self,
context,
self.project_id,
self.location,
self.dataflow_job_id,
)
return {"dataflow_job_id": self.dataflow_job_id}


@@ -425,13 +436,6 @@ def execute_sync(self, context: Context):

def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
with self.dataflow_hook.provide_authorized_gcloud():
self.defer(
trigger=BeamPythonPipelineTrigger(
@@ -443,6 +447,8 @@ def execute_async(self, context: Context):
py_system_site_packages=self.py_system_site_packages,
runner=self.runner,
gcp_conn_id=self.gcp_conn_id,
project_id=self.dataflow_config.project_id,
location=self.dataflow_config.location,
),
method_name="execute_complete",
)
@@ -613,13 +619,6 @@ def execute_sync(self, context: Context):

def execute_async(self, context: Context):
if self.is_dataflow and self.dataflow_hook:
DataflowJobLink.persist(
self,
context,
self.dataflow_config.project_id,
self.dataflow_config.location,
self.dataflow_job_id,
)
with self.dataflow_hook.provide_authorized_gcloud():
self.pipeline_options["jobName"] = self.dataflow_job_name
self.defer(
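On the operator side, the `DataflowJobLink.persist` call is dropped from `execute_async`: at deferral time the Dataflow job ID is not known yet. Instead, `execute_complete` now reads the job ID, project, and location from the trigger event and persists the link there. Under the success path added in the trigger changes below, that event has roughly this shape (values are illustrative):

```python
# Illustrative payload of the TriggerEvent that execute_complete consumes on success;
# execute_complete copies these fields onto the operator and then calls
# DataflowJobLink.persist(self, context, project_id, location, dataflow_job_id).
event = {
    "status": "success",
    "message": "Pipeline has finished SUCCESSFULLY",
    "dataflow_job_id": "2024-12-18_00_00_00-1234567890",  # extracted from the Beam process logs
    "project_id": "my-gcp-project",                       # passed through from dataflow_config
    "location": "us-central1",                            # passed through from dataflow_config
}
```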
42 changes: 38 additions & 4 deletions providers/src/airflow/providers/apache/beam/triggers/beam.py
@@ -19,12 +19,15 @@
import asyncio
import contextlib
from collections.abc import AsyncIterator, Sequence
from typing import IO, Any
from typing import IO, Any, Callable

from google.cloud.dataflow_v1beta3 import ListJobsRequest

from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook
from airflow.providers.google.cloud.hooks.dataflow import AsyncDataflowHook
from airflow.providers.apache.beam.hooks.beam import BeamAsyncHook, BeamRunnerType
from airflow.providers.google.cloud.hooks.dataflow import (
AsyncDataflowHook,
process_line_and_extract_dataflow_job_id_callback,
)
from airflow.providers.google.cloud.hooks.gcs import GCSHook
from airflow.triggers.base import BaseTrigger, TriggerEvent

@@ -40,6 +43,14 @@ def _get_async_hook(*args, **kwargs) -> BeamAsyncHook:
def _get_sync_dataflow_hook(**kwargs) -> AsyncDataflowHook:
return AsyncDataflowHook(**kwargs)

def _get_dataflow_process_callback(self) -> Callable[[str], None]:
def set_current_dataflow_job_id(job_id):
self.dataflow_job_id = job_id

return process_line_and_extract_dataflow_job_id_callback(
on_new_job_id_callback=set_current_dataflow_job_id
)


class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
"""
@@ -59,6 +70,8 @@ class BeamPythonPipelineTrigger(BeamPipelineBaseTrigger):
:param py_system_site_packages: Whether to include system_site_packages in your virtualenv.
See virtualenv documentation for more information.
This option is only relevant if the ``py_requirements`` parameter is not None.
:param project_id: Optional, the Google Cloud project ID in which to start a job.
:param location: Optional, Job location.
:param runner: Runner on which pipeline will be run. By default, "DirectRunner" is being used.
Other possible options: DataflowRunner, SparkRunner, FlinkRunner, PortableRunner.
See: :class:`~providers.apache.beam.hooks.beam.BeamRunnerType`
@@ -74,6 +87,8 @@ def __init__(
py_interpreter: str = "python3",
py_requirements: list[str] | None = None,
py_system_site_packages: bool = False,
project_id: str | None = None,
location: str | None = None,
runner: str = "DirectRunner",
gcp_conn_id: str = "google_cloud_default",
):
@@ -84,6 +99,9 @@ def __init__(
self.py_interpreter = py_interpreter
self.py_requirements = py_requirements
self.py_system_site_packages = py_system_site_packages
self.dataflow_job_id: str | None = None
self.project_id = project_id
self.location = location
self.runner = runner
self.gcp_conn_id = gcp_conn_id

@@ -98,6 +116,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
"py_interpreter": self.py_interpreter,
"py_requirements": self.py_requirements,
"py_system_site_packages": self.py_system_site_packages,
"project_id": self.project_id,
"location": self.location,
"runner": self.runner,
"gcp_conn_id": self.gcp_conn_id,
},
@@ -106,6 +126,8 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current pipeline status and yields a TriggerEvent."""
hook = self._get_async_hook(runner=self.runner)
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

try:
# Get the current running event loop to manage I/O operations asynchronously
loop = asyncio.get_running_loop()
@@ -130,6 +152,7 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
py_interpreter=self.py_interpreter,
py_requirements=self.py_requirements,
py_system_site_packages=self.py_system_site_packages,
process_line_callback=self._get_dataflow_process_callback() if is_dataflow else None,
)
except Exception as e:
self.log.exception("Exception occurred while checking for pipeline state")
@@ -140,6 +163,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
{
"status": "success",
"message": "Pipeline has finished SUCCESSFULLY",
"dataflow_job_id": self.dataflow_job_id,
"project_id": self.project_id,
"location": self.location,
}
)
else:
@@ -205,6 +231,7 @@ def __init__(
self.impersonation_chain = impersonation_chain
self.poll_sleep = poll_sleep
self.cancel_timeout = cancel_timeout
self.dataflow_job_id: str | None = None

def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize BeamJavaPipelineTrigger arguments and classpath."""
@@ -229,6 +256,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
"""Get current Java pipeline status and yields a TriggerEvent."""
hook = self._get_async_hook(runner=self.runner)
is_dataflow = self.runner.lower() == BeamRunnerType.DataflowRunner.lower()

return_code = 0
if self.check_if_running:
@@ -271,7 +299,10 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
self.jar = tmp_gcs_file.name

return_code = await hook.start_java_pipeline_async(
variables=self.variables, jar=self.jar, job_class=self.job_class
variables=self.variables,
jar=self.jar,
job_class=self.job_class,
process_line_callback=self._get_dataflow_process_callback() if is_dataflow else None,
)
except Exception as e:
self.log.exception("Exception occurred while starting the Java pipeline")
@@ -282,6 +313,9 @@ async def run(self) -> AsyncIterator[TriggerEvent]: # type: ignore[override]
{
"status": "success",
"message": "Pipeline has finished SUCCESSFULLY",
"dataflow_job_id": self.dataflow_job_id,
"project_id": self.project_id,
"location": self.location,
}
)
else:
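The trigger is where the job ID actually gets captured: `_get_dataflow_process_callback` returns a closure that stores the extracted ID on `self.dataflow_job_id`, the callback is only attached when the runner is DataflowRunner, and the ID is then included in the success event. A condensed, self-contained sketch of that capture pattern (the regex extractor is a simplified stand-in for the Google provider's `process_line_and_extract_dataflow_job_id_callback`; all names and values below are illustrative):

```python
import re
from typing import Callable

class TriggerSketch:
    """Stand-in for BeamPipelineBaseTrigger, showing only the job-id capture."""

    def __init__(self, runner: str) -> None:
        self.runner = runner
        self.dataflow_job_id: str | None = None

    def _get_dataflow_process_callback(self) -> Callable[[str], None]:
        def set_current_dataflow_job_id(job_id: str) -> None:
            self.dataflow_job_id = job_id

        def process_line(line: str) -> None:
            # Simplified extractor; the provider helper has its own matching logic.
            match = re.search(r"Submitted job:\s+(\S+)", line)
            if match:
                set_current_dataflow_job_id(match.group(1))

        return process_line


trigger = TriggerSketch(runner="DataflowRunner")
is_dataflow = trigger.runner.lower() == "dataflowrunner"  # mirrors the BeamRunnerType check
callback = trigger._get_dataflow_process_callback() if is_dataflow else None
if callback:
    callback("INFO: Submitted job: 2024-12-18_00_00_00-1234567890")
print(trigger.dataflow_job_id)  # -> 2024-12-18_00_00_00-1234567890
```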
12 changes: 10 additions & 2 deletions providers/tests/apache/beam/hooks/test_beam.py
@@ -486,7 +486,10 @@ async def test_start_pipline_async(self, mock_runner):
)

mock_runner.assert_called_once_with(
cmd=expected_cmd, working_directory=WORKING_DIRECTORY, log=hook.log
cmd=expected_cmd,
working_directory=WORKING_DIRECTORY,
log=hook.log,
process_line_callback=None,
)

@pytest.mark.asyncio
@@ -516,6 +519,7 @@ async def test_start_python_pipeline(self, mock_create_dir, mock_runner, mocked_
cmd=expected_cmd,
working_directory=None,
log=ANY,
process_line_callback=None,
)

@pytest.mark.asyncio
@@ -580,6 +584,7 @@ async def test_start_python_pipeline_with_custom_interpreter(
cmd=expected_cmd,
working_directory=None,
log=ANY,
process_line_callback=None,
)

@pytest.mark.asyncio
@@ -630,6 +635,7 @@ async def test_start_python_pipeline_with_non_empty_py_requirements_and_without_
cmd=expected_cmd,
working_directory=None,
log=ANY,
process_line_callback=None,
)
mock_virtualenv.assert_called_once_with(
venv_directory=mock.ANY,
@@ -671,5 +677,7 @@ async def test_start_java_pipeline_async(self, mock_start_pipeline, job_class, c
await hook.start_java_pipeline_async(variables=variables, jar=JAR_FILE, job_class=job_class)

mock_start_pipeline.assert_called_once_with(
variables=BEAM_VARIABLES_JAVA_STRING_LABELS, command_prefix=command_prefix
variables=BEAM_VARIABLES_JAVA_STRING_LABELS,
command_prefix=command_prefix,
process_line_callback=None,
)
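
If additional coverage for the new callback path were wanted, a test along these lines could exercise `read_logs` directly. It is not part of this PR, and it assumes `BeamAsyncHook(runner=...)` and `pytest-asyncio` are available as they are elsewhere in this test module:

```python
@pytest.mark.asyncio
async def test_read_logs_invokes_process_line_callback():
    # Hypothetical test, not part of this PR; constructor usage assumed from this module.
    hook = BeamAsyncHook(runner="DirectRunner")
    lines = [b"Submitted job: test-job-id\n", b""]

    class FakeStreamReader:
        async def readline(self):
            return lines.pop(0)

    seen: list[str] = []
    await hook.read_logs(FakeStreamReader(), process_line_callback=seen.append)
    assert seen == ["Submitted job: test-job-id"]
```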