From dd9e8982c6cd8a4d3e6044ebe698d061315bc8a7 Mon Sep 17 00:00:00 2001 From: Nate Robinson Date: Tue, 14 Jan 2025 19:24:55 -0500 Subject: [PATCH 1/3] Add support for timeout to BatchOperator --- .../src/airflow/providers/amazon/aws/operators/batch.py | 4 ++++ providers/tests/amazon/aws/operators/test_batch.py | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/providers/src/airflow/providers/amazon/aws/operators/batch.py b/providers/src/airflow/providers/amazon/aws/operators/batch.py index 3df00fb04c37f..ce4278fd787a9 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/batch.py +++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py @@ -95,6 +95,7 @@ class BatchOperator(BaseOperator): If it is an array job, only the logs of the first task will be printed. :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec. :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling. + :param boto3_timeout: Timeout configuration for SubmitJob. .. 
note:: Any custom waiters must return a waiter for these calls: @@ -184,6 +185,7 @@ def __init__( poll_interval: int = 30, awslogs_enabled: bool = False, awslogs_fetch_interval: timedelta = timedelta(seconds=30), + boto3_timeout: dict | None = None, **kwargs, ) -> None: BaseOperator.__init__(self, **kwargs) @@ -208,6 +210,7 @@ def __init__( self.poll_interval = poll_interval self.awslogs_enabled = awslogs_enabled self.awslogs_fetch_interval = awslogs_fetch_interval + self.boto3_timeout = boto3_timeout # params for hook self.max_retries = max_retries @@ -313,6 +316,7 @@ def submit_job(self, context: Context): "retryStrategy": self.retry_strategy, "shareIdentifier": self.share_identifier, "schedulingPriorityOverride": self.scheduling_priority_override, + "timeout": self.boto3_timeout, } try: diff --git a/providers/tests/amazon/aws/operators/test_batch.py b/providers/tests/amazon/aws/operators/test_batch.py index 0c14c256edba9..fad3eed6b6f5e 100644 --- a/providers/tests/amazon/aws/operators/test_batch.py +++ b/providers/tests/amazon/aws/operators/test_batch.py @@ -70,6 +70,7 @@ def setup_method(self, _, get_client_type_mock): aws_conn_id="airflow_test", region_name="eu-west-1", tags={}, + boto3_timeout={"attemptDurationSeconds": 3600}, ) self.client_mock = self.get_client_type_mock.return_value # We're mocking all actual AWS calls and don't need a connection. 
This @@ -109,6 +110,7 @@ def test_init(self): assert self.batch.hook.client == self.client_mock assert self.batch.tags == {} assert self.batch.wait_for_completion is True + assert self.batch.boto3_timeout == {"attemptDurationSeconds": 3600} self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") @@ -141,6 +143,7 @@ def test_init_defaults(self): assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) assert batch_job.tags == {} assert batch_job.wait_for_completion is True + assert batch_job.boto3_timeout is None def test_template_fields_overrides(self): assert self.batch.template_fields == ( @@ -181,6 +184,7 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) assert self.batch.job_id == JOB_ID @@ -205,6 +209,7 @@ def test_execute_with_failures(self): parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "get_job_description") @@ -261,6 +266,7 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "get_job_description") @@ -359,6 +365,7 @@ def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "check_job_success") From 3fcf9bd204e4d8eecc674e03d803998fa7b383b0 Mon Sep 17 00:00:00 2001 From: Nate Robinson Date: Wed, 15 Jan 2025 12:15:29 -0500 Subject: [PATCH 2/3] Rename to batch_execution_timeout --- .../airflow/providers/amazon/aws/operators/batch.py | 10 ++++++---- providers/tests/amazon/aws/operators/test_batch.py | 6 +++--- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git 
a/providers/src/airflow/providers/amazon/aws/operators/batch.py b/providers/src/airflow/providers/amazon/aws/operators/batch.py index ce4278fd787a9..b7303ea779198 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/batch.py +++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py @@ -95,7 +95,7 @@ class BatchOperator(BaseOperator): If it is an array job, only the logs of the first task will be printed. :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec. :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling. - :param boto3_timeout: Timeout configuration for SubmitJob. + :param batch_execution_timeout: Execution timeout in seconds for submitted batch job. .. note:: Any custom waiters must return a waiter for these calls: @@ -185,7 +185,7 @@ def __init__( poll_interval: int = 30, awslogs_enabled: bool = False, awslogs_fetch_interval: timedelta = timedelta(seconds=30), - boto3_timeout: dict | None = None, + batch_execution_timeout: int | None = None, **kwargs, ) -> None: BaseOperator.__init__(self, **kwargs) @@ -210,7 +210,7 @@ def __init__( self.poll_interval = poll_interval self.awslogs_enabled = awslogs_enabled self.awslogs_fetch_interval = awslogs_fetch_interval - self.boto3_timeout = boto3_timeout + self.batch_execution_timeout = batch_execution_timeout # params for hook self.max_retries = max_retries @@ -316,9 +316,11 @@ def submit_job(self, context: Context): "retryStrategy": self.retry_strategy, "shareIdentifier": self.share_identifier, "schedulingPriorityOverride": self.scheduling_priority_override, - "timeout": self.boto3_timeout, } + if self.batch_execution_timeout: + args["timeout"] = {"attemptDurationSeconds": self.batch_execution_timeout} + try: response = self.hook.client.submit_job(**trim_none_values(args)) except Exception as e: diff --git a/providers/tests/amazon/aws/operators/test_batch.py b/providers/tests/amazon/aws/operators/test_batch.py index 
fad3eed6b6f5e..8ea964c0594ed 100644 --- a/providers/tests/amazon/aws/operators/test_batch.py +++ b/providers/tests/amazon/aws/operators/test_batch.py @@ -70,7 +70,7 @@ def setup_method(self, _, get_client_type_mock): aws_conn_id="airflow_test", region_name="eu-west-1", tags={}, - boto3_timeout={"attemptDurationSeconds": 3600}, + batch_execution_timeout=3600, ) self.client_mock = self.get_client_type_mock.return_value # We're mocking all actual AWS calls and don't need a connection. This @@ -110,7 +110,7 @@ def test_init(self): assert self.batch.hook.client == self.client_mock assert self.batch.tags == {} assert self.batch.wait_for_completion is True - assert self.batch.boto3_timeout == {"attemptDurationSeconds": 3600} + assert self.batch.batch_execution_timeout == 3600 self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") @@ -143,7 +143,7 @@ def test_init_defaults(self): assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) assert batch_job.tags == {} assert batch_job.wait_for_completion is True - assert batch_job.boto3_timeout is None + assert batch_job.batch_execution_timeout is None def test_template_fields_overrides(self): assert self.batch.template_fields == ( From 2f2066b43bd8cba8c73e4ec392c67d30d626c2c5 Mon Sep 17 00:00:00 2001 From: Nate Robinson Date: Wed, 15 Jan 2025 15:15:32 -0500 Subject: [PATCH 3/3] Rename to submit_job_timeout --- .../airflow/providers/amazon/aws/operators/batch.py | 10 +++++----- providers/tests/amazon/aws/operators/test_batch.py | 6 +++--- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/providers/src/airflow/providers/amazon/aws/operators/batch.py b/providers/src/airflow/providers/amazon/aws/operators/batch.py index b7303ea779198..e69508d89319f 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/batch.py +++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py @@ -95,7 +95,7 @@ class BatchOperator(BaseOperator): If it is an array job, only the logs of 
the first task will be printed. :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec. :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling. - :param batch_execution_timeout: Execution timeout in seconds for submitted batch job. + :param submit_job_timeout: Per-attempt job timeout in seconds (SubmitJob ``attemptDurationSeconds``). .. note:: Any custom waiters must return a waiter for these calls: @@ -185,7 +185,7 @@ def __init__( poll_interval: int = 30, awslogs_enabled: bool = False, awslogs_fetch_interval: timedelta = timedelta(seconds=30), - batch_execution_timeout: int | None = None, + submit_job_timeout: int | None = None, **kwargs, ) -> None: BaseOperator.__init__(self, **kwargs) @@ -210,7 +210,7 @@ def __init__( self.poll_interval = poll_interval self.awslogs_enabled = awslogs_enabled self.awslogs_fetch_interval = awslogs_fetch_interval - self.batch_execution_timeout = batch_execution_timeout + self.submit_job_timeout = submit_job_timeout # params for hook self.max_retries = max_retries @@ -318,8 +318,8 @@ def submit_job(self, context: Context): "schedulingPriorityOverride": self.scheduling_priority_override, } - if self.batch_execution_timeout: - args["timeout"] = {"attemptDurationSeconds": self.batch_execution_timeout} + if self.submit_job_timeout: + args["timeout"] = {"attemptDurationSeconds": self.submit_job_timeout} try: response = self.hook.client.submit_job(**trim_none_values(args)) diff --git a/providers/tests/amazon/aws/operators/test_batch.py b/providers/tests/amazon/aws/operators/test_batch.py index 8ea964c0594ed..c1b1d847b7d91 100644 --- a/providers/tests/amazon/aws/operators/test_batch.py +++ b/providers/tests/amazon/aws/operators/test_batch.py @@ -70,7 +70,7 @@ def setup_method(self, _, get_client_type_mock): aws_conn_id="airflow_test", region_name="eu-west-1", tags={}, - batch_execution_timeout=3600, + submit_job_timeout=3600, ) self.client_mock = self.get_client_type_mock.return_value # 
We're mocking all actual AWS calls and don't need a connection. This @@ -110,7 +110,7 @@ def test_init(self): assert self.batch.hook.client == self.client_mock assert self.batch.tags == {} assert self.batch.wait_for_completion is True - assert self.batch.batch_execution_timeout == 3600 + assert self.batch.submit_job_timeout == 3600 self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") @@ -143,7 +143,7 @@ def test_init_defaults(self): assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) assert batch_job.tags == {} assert batch_job.wait_for_completion is True - assert batch_job.batch_execution_timeout is None + assert batch_job.submit_job_timeout is None def test_template_fields_overrides(self): assert self.batch.template_fields == (