Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Databricks Agent Phase #2244

Merged
merged 4 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions flytekit/extend/backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ def convert_to_flyte_phase(state: str) -> TaskExecution.Phase:
Convert the state from the agent to the phase in flyte.
"""
state = state.lower()
# timedout is the state of Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
if state in ["failed", "timeout", "timedout", "canceled"]:
# timedout, skipped, internal_error, terminating and pending are states of a Databricks job. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
pingsutw marked this conversation as resolved.
Show resolved Hide resolved
if state in ["failed", "timeout", "timedout", "canceled", "skipped", "internal_error"]:
return TaskExecution.FAILED
elif state in ["done", "succeeded", "success"]:
return TaskExecution.SUCCEEDED
elif state in ["running"]:
elif state in ["running", "terminating"]:
return TaskExecution.RUNNING
elif state in ["pending"]:
return TaskExecution.INITIALIZING
raise ValueError(f"Unrecognized state: {state}")


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def get(self, resource_meta: BigQueryMetadata, **kwargs) -> Resource:
output_location = f"bq://{dst.project}:{dst.dataset_id}.{dst.table_id}"
res = TypeEngine.dict_to_literal_map(ctx, {"results": StructuredDataset(uri=output_location)})

return Resource(phase=cur_phase, message=job.state, log_links=[log_link], outputs=res)
return Resource(phase=cur_phase, message=str(job.state), log_links=[log_link], outputs=res)

def delete(self, resource_meta: BigQueryMetadata, **kwargs):
client = bigquery.Client()
Expand Down
20 changes: 15 additions & 5 deletions plugins/flytekit-spark/flytekitplugins/spark/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,20 @@
raise Exception(f"Failed to get databricks job {resource_meta.run_id} with error: {resp.reason}")
response = await resp.json()

cur_phase = TaskExecution.RUNNING
cur_phase = TaskExecution.UNDEFINED
message = ""
state = response.get("state")

# The Databricks job's state is determined by life_cycle_state and result_state. https://docs.databricks.com/en/workflows/jobs/jobs-2.0-api.html#runresultstate
if state:
if state.get("result_state"):
cur_phase = convert_to_flyte_phase(state["result_state"])
if state.get("state_message"):
message = state["state_message"]
life_cycle_state = state.get("life_cycle_state")
if result_state_is_available(life_cycle_state):
result_state = state.get("result_state")
cur_phase = convert_to_flyte_phase(result_state)
else:
cur_phase = convert_to_flyte_phase(life_cycle_state)

Check warning on line 95 in plugins/flytekit-spark/flytekitplugins/spark/agent.py

View check run for this annotation

Codecov / codecov/patch

plugins/flytekit-spark/flytekitplugins/spark/agent.py#L95

Added line #L95 was not covered by tests

message = state.get("state_message")

job_id = response.get("job_id")
databricks_console_url = f"https://{databricks_instance}/#job/{job_id}/run/{resource_meta.run_id}"
Expand All @@ -112,4 +118,8 @@
return {"Authorization": f"Bearer {token}", "content-type": "application/json"}


def result_state_is_available(life_cycle_state: str) -> bool:
    """
    Tell whether a run's ``result_state`` should be read for the given life cycle state.

    :param life_cycle_state: The ``life_cycle_state`` value taken from the run's ``state`` payload.
    :return: True only when ``life_cycle_state`` is exactly ``"TERMINATED"`` (case-sensitive).
        Any other value, including ``None``, yields False.
    """
    # The comparison is deliberately exact: no case folding or stripping is applied.
    return life_cycle_state == "TERMINATED"


AgentRegistry.register(DatabricksAgent())
6 changes: 5 additions & 1 deletion plugins/flytekit-spark/tests/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,11 @@ async def test_databricks_agent():
)

mock_create_response = {"run_id": "123"}
mock_get_response = {"job_id": "1", "run_id": "123", "state": {"result_state": "SUCCESS", "state_message": "OK"}}
mock_get_response = {
"job_id": "1",
"run_id": "123",
"state": {"life_cycle_state": "TERMINATED", "result_state": "SUCCESS", "state_message": "OK"},
}
mock_delete_response = {}
create_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/submit"
get_url = f"https://test-account.cloud.databricks.com{DATABRICKS_API_ENDPOINT}/runs/get?run_id=123"
Expand Down
5 changes: 5 additions & 0 deletions tests/flytekit/unit/extend/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,12 +288,17 @@ def test_convert_to_flyte_phase():
assert convert_to_flyte_phase("TIMEOUT") == TaskExecution.FAILED
assert convert_to_flyte_phase("TIMEDOUT") == TaskExecution.FAILED
assert convert_to_flyte_phase("CANCELED") == TaskExecution.FAILED
assert convert_to_flyte_phase("SKIPPED") == TaskExecution.FAILED
assert convert_to_flyte_phase("INTERNAL_ERROR") == TaskExecution.FAILED

assert convert_to_flyte_phase("DONE") == TaskExecution.SUCCEEDED
assert convert_to_flyte_phase("SUCCEEDED") == TaskExecution.SUCCEEDED
assert convert_to_flyte_phase("SUCCESS") == TaskExecution.SUCCEEDED

assert convert_to_flyte_phase("RUNNING") == TaskExecution.RUNNING
assert convert_to_flyte_phase("TERMINATING") == TaskExecution.RUNNING

assert convert_to_flyte_phase("PENDING") == TaskExecution.INITIALIZING

invalid_state = "INVALID_STATE"
with pytest.raises(Exception, match=f"Unrecognized state: {invalid_state.lower()}"):
Expand Down
Loading