diff --git a/providers/src/airflow/providers/amazon/aws/operators/batch.py b/providers/src/airflow/providers/amazon/aws/operators/batch.py index 3df00fb04c37f..e69508d89319f 100644 --- a/providers/src/airflow/providers/amazon/aws/operators/batch.py +++ b/providers/src/airflow/providers/amazon/aws/operators/batch.py @@ -95,6 +95,7 @@ class BatchOperator(BaseOperator): If it is an array job, only the logs of the first task will be printed. :param awslogs_fetch_interval: The interval with which cloudwatch logs are to be fetched, 30 sec. :param poll_interval: (Deferrable mode only) Time in seconds to wait between polling. + :param submit_job_timeout: Execution timeout in seconds for submitted batch job. .. note:: Any custom waiters must return a waiter for these calls: @@ -184,6 +185,7 @@ def __init__( poll_interval: int = 30, awslogs_enabled: bool = False, awslogs_fetch_interval: timedelta = timedelta(seconds=30), + submit_job_timeout: int | None = None, **kwargs, ) -> None: BaseOperator.__init__(self, **kwargs) @@ -208,6 +210,7 @@ def __init__( self.poll_interval = poll_interval self.awslogs_enabled = awslogs_enabled self.awslogs_fetch_interval = awslogs_fetch_interval + self.submit_job_timeout = submit_job_timeout # params for hook self.max_retries = max_retries @@ -315,6 +318,9 @@ def submit_job(self, context: Context): "schedulingPriorityOverride": self.scheduling_priority_override, } + if self.submit_job_timeout: + args["timeout"] = {"attemptDurationSeconds": self.submit_job_timeout} + try: response = self.hook.client.submit_job(**trim_none_values(args)) except Exception as e: diff --git a/providers/tests/amazon/aws/operators/test_batch.py b/providers/tests/amazon/aws/operators/test_batch.py index 0c14c256edba9..c1b1d847b7d91 100644 --- a/providers/tests/amazon/aws/operators/test_batch.py +++ b/providers/tests/amazon/aws/operators/test_batch.py @@ -70,6 +70,7 @@ def setup_method(self, _, get_client_type_mock): aws_conn_id="airflow_test", region_name="eu-west-1", tags={}, + submit_job_timeout=3600, ) self.client_mock = self.get_client_type_mock.return_value # We're mocking all actual AWS calls and don't need a connection. This @@ -109,6 +110,7 @@ def test_init(self): assert self.batch.hook.client == self.client_mock assert self.batch.tags == {} assert self.batch.wait_for_completion is True + assert self.batch.submit_job_timeout == 3600 self.get_client_type_mock.assert_called_once_with(region_name="eu-west-1") @@ -141,6 +143,7 @@ def test_init_defaults(self): assert issubclass(type(batch_job.hook.client), botocore.client.BaseClient) assert batch_job.tags == {} assert batch_job.wait_for_completion is True + assert batch_job.submit_job_timeout is None def test_template_fields_overrides(self): assert self.batch.template_fields == ( @@ -181,6 +184,7 @@ def test_execute_without_failures(self, check_mock, wait_mock, job_description_m parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) assert self.batch.job_id == JOB_ID @@ -205,6 +209,7 @@ def test_execute_with_failures(self): parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "get_job_description") @@ -261,6 +266,7 @@ def test_execute_with_ecs_overrides(self, check_mock, wait_mock, job_description parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "get_job_description") @@ -359,6 +365,7 @@ def test_execute_with_eks_overrides(self, check_mock, wait_mock, job_description parameters={}, retryStrategy={"attempts": 1}, tags={}, + timeout={"attemptDurationSeconds": 3600}, ) @mock.patch.object(BatchClientHook, "check_job_success")