Skip to content

Commit

Permalink
Refactor Databricks Agent Phase (#2244)
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier authored Mar 10, 2024
1 parent 92bf7a7 commit d1baacc
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 10 deletions.
7 changes: 4 additions & 3 deletions flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
Convert the state from the agent to the phase in flyte.
"""
state = state.lower()
# "timedout" is a result state reported by Databricks jobs. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
if state in ["failed", "timeout", "timedout", "canceled"]:
if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
return TaskExecution.FAILED
elif state in ["done", "succeeded", "success"]:
return TaskExecution.SUCCEEDED
elif state in ["running"]:
elif state in ["running", "terminating"]:
return TaskExecution.RUNNING
elif state in ["pending"]:
return TaskExecution.INITIALIZING
raise ValueError(f"Unrecognized state: {state}")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource:
output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}"
res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)})

return Resource(phase=cur_phase, message=job.state, log_links=[log_link], outputs=res)
return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res)

def delete(self, resource_meta: BigQueryMetadata, **kwargs):
client = bigquery.Client()
Expand Down
20 changes: 15 additions & 5 deletions plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,20 @@ async def get(self, resource_meta: DatabricksJobMetadata, **kwargs) -> Resource:
raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
response = await resp.json()

cur_phase = TaskExecution.RUNNING
cur_phase = TaskExecution.UNDEFINED
message = ""
state = response.get("state")

# A Databricks job's state is determined by both life_cycle_state and result_state. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
if state:
if state.get("result_state"):
cur_phase = convert_to_flyte_phase(state["result_state"])
if state.get("state_message"):
message = state["state_message"]
life_cycle_state = state.get("life_cycle_state")
if result_state_is_available(life_cycle_state):
result_state = state.get("result_state")
cur_phase = convert_to_flyte_phase(result_state)
else:
cur_phase = convert_to_flyte_phase(life_cycle_state)

message = state.get("state_message")

job_id = response.get("job_id")
databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}"
Expand All @@ -112,4 +118,8 @@ def get_header() -> typing.Dict[str, str]:
return {"Authorization": f"Bearer {token}", "content-type": "application/json"}


def result_state_is_available(life_cycle_state: typing.Optional[str]) -> bool:
    """Tell whether a Databricks run's ``result_state`` should be consulted.

    The agent's ``get`` uses this to decide between mapping ``result_state``
    (when this returns True) and mapping ``life_cycle_state`` itself
    (when this returns False) to a Flyte phase.

    Args:
        life_cycle_state: The ``life_cycle_state`` value taken from the run's
            ``state`` object. May be ``None`` when the key is absent, because
            the caller reads it with ``dict.get``.

    Returns:
        True only when the life cycle state is ``TERMINATED``; False for any
        other value, including ``None``.
    """
    # The caller may pass None (missing key); treat that as "not available"
    # rather than relying on an incidental comparison result.
    if not isinstance(life_cycle_state, str):
        return False
    # Compare case-insensitively, consistent with convert_to_flyte_phase,
    # which normalizes state names before matching them.
    return life_cycle_state.upper() == "TERMINATED"


AgentRegistry.register(DatabricksAgent())
6 changes: 5 additions & 1 deletion plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ async def test_databricks_agent():
)

mock_create_response = {"run_id": "123"}
mock_get_response = {"job_id": "1", "run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}}
mock_get_response = {
"job_id": "1",
"run_id": "123",
"state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"},
}
mock_delete_response = {}
create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit"
get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123"
Expand Down
5 changes: 5 additions & 0 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,17 @@ def test_convert_to_flyte_phase():
assert convert_to_flyte_phase("TIMEOUT") == TaskExecution.FAILED
assert convert_to_flyte_phase("TIMEDOUT") == TaskExecution.FAILED
assert convert_to_flyte_phase("CANCELED") == TaskExecution.FAILED
assert convert_to_flyte_phase("SKIPPED") == TaskExecution.FAILED
assert convert_to_flyte_phase("INTERNAL_ERROR") == TaskExecution.FAILED

assert convert_to_flyte_phase("DONE") == TaskExecution.SUCCEEDED
assert convert_to_flyte_phase("SUCCEEDED") == TaskExecution.SUCCEEDED
assert convert_to_flyte_phase("SUCCESS") == TaskExecution.SUCCEEDED

assert convert_to_flyte_phase("RUNNING") == TaskExecution.RUNNING
assert convert_to_flyte_phase("TERMINATING") == TaskExecution.RUNNING

assert convert_to_flyte_phase("PENDING") == TaskExecution.INITIALIZING

invalid_state = "INVALID_STATE"
with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"):
Expand Down

0 comments on commit d1baacc

Please sign in to comment.