From d1baacc02434a5e5289dedf6e5194b7894f1ff3f Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Sun, 10 Mar 2024 13:45:40 +0800 Subject: [PATCH] Refactor Databricks Agent Phase (#2244) Signed-off-by: Future-Outlier --- flytekit/extend/backend/utils.py | 7 ++++--- .../flytekitplugins/bigquery/agent.py | 2 +- .../flytekitplugins/spark/agent.py | 20 ++++++++++++++----- plugins/flytekit-spark/tests/test_agent.py | 6 +++++- tests/flytekit/unit/extend/test_agent.py | 5 +++++ 5 files changed, 30 insertions(+), 10 deletions(-) diff --git a/flytekit/extend/backend/utils.py b/flytekit/extend/backend/utils.py index b20c9fdf66..5199536b5d 100644 --- a/flytekit/extend/backend/utils.py +++ b/flytekit/extend/backend/utils.py @@ -20,13 +20,14 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase: Convert the state from the agent to the phase in flyte. """ state = state.lower() - # timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate - if state in ["failed", "timeout", "timedout", "canceled"]: + if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]: return TaskExecution.FAILED elif state in ["done", "succeeded", "success"]: return TaskExecution.SUCCEEDED - elif state in ["running"]: + elif state in ["running", "terminating"]: return TaskExecution.RUNNING + elif state in ["pending"]: + return TaskExecution.INITIALIZING raise ValueError(f"Unrecognized state: {state}") diff --git a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py index f6b7cfd6e6..813cc1794a 100644 --- a/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py +++ b/plugins/flytekit-bigquery/flytekitplugins/bigquery/agent.py @@ -88,7 +88,7 @@ def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource: output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}" res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)}) - return Resource(phase=cur_phase, message=job.state, log_links=[log_link], outputs=res) + return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res) def delete(self, resource_meta: BigQueryMetadata, **kwargs): client = bigquery.Client() diff --git a/plugins/flytekit-spark/flytekitplugins/spark/agent.py b/plugins/flytekit-spark/flytekitplugins/spark/agent.py index 8200263ac3..d367f3f04a 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/agent.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/agent.py @@ -81,14 +81,20 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource: raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}") response = await resp.json() - cur_phase = TaskExecution.RUNNING + cur_phase = TaskExecution.UNDEFINED message = "" state = response.get("state") + + # The databricks job's state is determined by life_cycle_state and result_state. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate if state: - if state.get("result_state"): - cur_phase = convert_to_flyte_phase(state["result_state"]) - if state.get("state_message"): - message = state["state_message"] + life_cycle_state = state.get("life_cycle_state") + if result_state_is_available(life_cycle_state): + result_state = state.get("result_state") + cur_phase = convert_to_flyte_phase(result_state) + else: + cur_phase = convert_to_flyte_phase(life_cycle_state) + + message = state.get("state_message") job_id = response.get("job_id") databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}" @@ -112,4 +118,8 @@ def get_header() -> typing.Dict[str, str]: return {"Authorization": f"Bearer {token}", "content-type": "application/json"} +def result_state_is_available(life_cycle_state: str) -> bool: + return life_cycle_state == "TERMINATED" + + AgentRegistry.register(DatabricksAgent()) diff --git a/plugins/flytekit-spark/tests/test_agent.py b/plugins/flytekit-spark/tests/test_agent.py index 80f91c5c76..642755a351 100644 --- a/plugins/flytekit-spark/tests/test_agent.py +++ b/plugins/flytekit-spark/tests/test_agent.py @@ -108,7 +108,11 @@ async def test_databricks_agent(): ) mock_create_response = {"run_id": "123"} - mock_get_response = {"job_id": "1", "run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}} + mock_get_response = { + "job_id": "1", + "run_id": "123", + "state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"}, + } mock_delete_response = {} create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit" get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123" diff --git a/tests/flytekit/unit/extend/test_agent.py b/tests/flytekit/unit/extend/test_agent.py index 85c88def45..8a1289f974 100644 --- a/tests/flytekit/unit/extend/test_agent.py +++ b/tests/flytekit/unit/extend/test_agent.py @@ -288,12 +288,17 @@ def test_convert_to_flyte_phase(): assert convert_to_flyte_phase("TIMEOUT") == TaskExecution.FAILED assert convert_to_flyte_phase("TIMEDOUT") == TaskExecution.FAILED assert convert_to_flyte_phase("CANCELED") == TaskExecution.FAILED + assert convert_to_flyte_phase("SKIPPED") == TaskExecution.FAILED + assert convert_to_flyte_phase("INTERNAL_ERROR") == TaskExecution.FAILED assert convert_to_flyte_phase("DONE") == TaskExecution.SUCCEEDED assert convert_to_flyte_phase("SUCCEEDED") == TaskExecution.SUCCEEDED assert convert_to_flyte_phase("SUCCESS") == TaskExecution.SUCCEEDED assert convert_to_flyte_phase("RUNNING") == TaskExecution.RUNNING + assert convert_to_flyte_phase("TERMINATING") == TaskExecution.RUNNING + + assert convert_to_flyte_phase("PENDING") == TaskExecution.INITIALIZING invalid_state = "INVALID_STATE" with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"):