Skip to content

Commit

Permalink
Finish the I/O handling for the server Flow class
Browse files Browse the repository at this point in the history
  • Loading branch information
SebastianScherer88 committed Dec 7, 2024
1 parent 3fd9e56 commit b58625f
Showing 1 changed file with 84 additions and 199 deletions.
283 changes: 84 additions & 199 deletions sdk/bettmensch_ai/server/flow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from datetime import datetime
from typing import Dict, List, Literal, Optional, Tuple, Union
from typing import Dict, List, Literal, Optional, Tuple, Type, Union

from bettmensch_ai.server.pipeline import (
NodeArtifactInput,
Expand All @@ -18,7 +18,6 @@
ResourceTemplate,
ScriptTemplate,
)
from bettmensch_ai.server.utils import copy_non_null_dict
from hera.workflows.models import NodeStatus as NodeStatusModel
from hera.workflows.models import Workflow as WorkflowModel
from hera.workflows.models import WorkflowSpec as WorkflowSpecModel
Expand Down Expand Up @@ -191,29 +190,31 @@ def get_node_status_by_display_name(

@classmethod
def _get_model_classes(
cls,
pipeline_io: Union[
type[PipelineInputs],
type[PipelineOutputs],
type[NodeInputs],
type[NodeOutputs],
]
Type[PipelineInputs],
Type[PipelineOutputs],
Type[NodeInputs],
Type[NodeOutputs],
],
) -> Tuple[
Literal["inputs", "outputs"],
Union[
type[FlowParameterInput],
type[FlowParameterOutput],
type[FlowNodeParameterInput],
type[FlowNodeParameterOutput],
Type[FlowParameterInput],
Type[FlowParameterOutput],
Type[FlowNodeParameterInput],
Type[FlowNodeParameterOutput],
],
Union[
type[FlowArtifactOutput],
type[FlowNodeArtifactInput],
type[FlowNodeArtifactOutput],
Type[FlowArtifactOutput],
Type[FlowNodeArtifactInput],
Type[FlowNodeArtifactOutput],
],
Union[
type[FlowInputs],
type[FlowOutputs],
type[FlowNodeInputs],
type[FlowNodeOutputs],
Type[FlowInputs],
Type[FlowOutputs],
Type[FlowNodeInputs],
Type[FlowNodeOutputs],
],
]:
"""_summary_
Expand Down Expand Up @@ -249,17 +250,24 @@ def _get_model_classes(
"""

if pipeline_io == PipelineInputs:
return FlowParameterInput, None, FlowInputs
return "inputs", FlowParameterInput, None, FlowInputs
elif pipeline_io == PipelineOutputs:
return FlowParameterOutput, FlowArtifactOutput, FlowOutputs
return (
"outputs",
FlowParameterOutput,
FlowArtifactOutput,
FlowOutputs,
)
elif pipeline_io == NodeInputs:
return (
"inputs",
FlowNodeParameterInput,
FlowNodeArtifactInput,
FlowNodeInputs,
)
elif pipeline_io == NodeOutputs:
return (
"outputs",
FlowNodeParameterOutput,
FlowNodeArtifactOutput,
FlowNodeOutputs,
Expand All @@ -276,7 +284,6 @@ def _get_model_classes(
@classmethod
def _build_generic_io(
cls,
io_type: Literal["inputs", "outputs"],
pipeline_io: Union[
PipelineInputs,
PipelineOutputs,
Expand All @@ -286,34 +293,35 @@ def _build_generic_io(
flow_node: NodeStatusModel,
) -> Union[FlowInputs, FlowOutputs, FlowNodeInputs, FlowNodeOutputs]:
"""Parametrizable utility function to build either inputs or outputs
for the Flow or one of its FlowNodes.
for the Flow or one of its FlowNodes using Pipeline or Node I/Os.
Args:
io_type (Literal["inputs", "outputs"]): _description_
pipeline_io (Union[
PipelineInputs,
PipelineOutputs,
NodeInputs,
NodeOutputs,
]): _description_
flow_node (NodeStatusModel): _description_
flow_node (NodeStatusModel): A pipeline or pipeline node I/O
instance
Returns:
Union[FlowInputs, FlowOutputs, FlowNodeInputs, FlowNodeOutputs]:
_description_
A flow or flow node I/O instance
"""

io = {"parameters": [], "artifacts": []}

(
io_type,
parameter_io_class,
artifact_io_class,
io_class,
) = cls._get_model_classes(pipeline_io.__class__)

# parameters
for ppo in pipeline_io.parameters:
flow_parameter_output_data = ppo.model_dump()
flow_parameter_io_data = ppo.model_dump()
try:
fpo_value = [
fpo
Expand All @@ -323,29 +331,32 @@ def _build_generic_io(
except (AttributeError, IndexError):
fpo_value = None
finally:
flow_parameter_output_data["value"] = fpo_value
flow_parameter_io_data["value"] = fpo_value
flow_parameter_output = parameter_io_class.model_validate(
flow_parameter_output_data
flow_parameter_io_data
)
io["parameters"].append(flow_parameter_output)

# artifacts
for pao in pipeline_io.artifacts:
flow_artifact_output_data = pao.model_dump()
try:
fao_s3 = [
fao
for fao in getattr(flow_node, io_type).artifacts
if fao.name == pao.name
][0].s3.dict()
except (AttributeError, IndexError):
fao_s3 = None
finally:
flow_artifact_output_data["s3"] = fao_s3
flow_parameter_output = artifact_io_class.model_validate(
flow_artifact_output_data
)
io["artifacts"].append(flow_parameter_output)
if artifact_io_class is not None:
for pao in pipeline_io.artifacts:
flow_artifact_io_data = pao.model_dump()
try:
fao_s3 = [
fao
for fao in getattr(flow_node, io_type).artifacts
if fao.name == pao.name
][0].s3.dict()
except (AttributeError, IndexError):
fao_s3 = None
finally:
flow_artifact_io_data["s3"] = fao_s3
flow_parameter_output = artifact_io_class.model_validate(
flow_artifact_io_data
)
io["artifacts"].append(flow_parameter_output)
else:
del io["artifacts"]

return io_class.model_validate(io)

Expand Down Expand Up @@ -373,81 +384,13 @@ def build_io(
)

# --- inputs
flow_inputs = {
"parameters": [],
}

# parameters
for ppi in pipeline_inputs.parameters:
flow_parameter_input_data = ppi.model_dump()
try:
fpi_value = [
fpi
for fpi in inner_dag_node.inputs.parameters
if fpi.name == ppi.name
][0].value
except (AttributeError, IndexError):
fpi_value = None
finally:
flow_parameter_input_data["value"] = fpi_value
flow_parameter_input = FlowParameterInput.model_validate(
flow_parameter_input_data
)
flow_inputs["parameters"].append(flow_parameter_input)
flow_inputs = cls._build_generic_io(pipeline_inputs, inner_dag_node)

# --- outputs
flow_outputs = {"parameters": [], "artifacts": []}

# parameters
for ppo in pipeline_outputs.parameters:
flow_parameter_output_data = ppo.model_dump()
try:
fpo_value = [
fpo
for fpo in inner_dag_node.outputs.parameters
if fpo.name == ppo.name
][0].value
except (AttributeError, IndexError):
fpo_value = None
finally:
flow_parameter_output_data["value"] = fpo_value
flow_parameter_output = FlowParameterOutput.model_validate(
flow_parameter_output_data
)
flow_outputs["parameters"].append(flow_parameter_output)

# artifacts
for pao in pipeline_outputs.artifacts:
flow_artifact_output_data = pao.model_dump()
try:
fao_s3 = [
fao
for fao in inner_dag_node.outputs.artifacts
if fao.name == pao.name
][0].s3.dict()
except (AttributeError, IndexError):
fao_s3 = None
finally:
flow_artifact_output_data["s3"] = fao_s3
flow_parameter_output = FlowArtifactOutput.model_validate(
flow_artifact_output_data
)
flow_outputs["artifacts"].append(flow_parameter_output)

flow_outputs = FlowOutputs.model_validate(flow_outputs)
flow_outputs = cls._build_generic_io(pipeline_outputs, inner_dag_node)

return flow_inputs, flow_outputs

@classmethod
def build_flow_node_io(
cls,
pipeline_node_io: Union[NodeArtifactInput, NodeArtifactOutput],
flow_node: NodeStatusModel,
) -> Tuple[FlowNodeInputs, FlowNodeOutputs]:

if flow_node is not None:
flow_node = copy_non_null_dict(flow_node.dict())

@classmethod
def build_flow_node(
cls,
Expand All @@ -470,93 +413,35 @@ def build_flow_node(
"depends": pipeline_node.depends,
}

try:
flow_node = workflow_nodes_dict[pipeline_node.name]
except KeyError:
flow_node = None
flow_node = workflow_nodes_dict.get(pipeline_node.name, None)

flow_node_inputs, flow_node_outputs = cls.build_flow_node_io(
pipeline_node, flow_node
)
if flow_node is None:
flow_node_data["pod_name"] = pipeline_node.name
flow_node_data["phase"] = "Not Scheduled"
flow_node_inputs, flow_node_outputs = (
pipeline_node.inputs.model_dump(),
pipeline_node.outputs.model_dump(),
)
else:
flow_node_data["id"] = flow_node.id
flow_node_data["type"] = flow_node.type
flow_node_data["pod_name"] = flow_node.name
flow_node_data["phase"] = flow_node.phase
flow_node_data["dependants"] = getattr(flow_node, "children", None)
flow_node_data["host_node_name"] = getattr(
flow_node, "host_node_name", None
)
flow_node_inputs = cls._build_generic_io(
pipeline_node.inputs, flow_node
)
flow_node_outputs = cls._build_generic_io(
pipeline_node.outputs, flow_node
)

flow_node_data["inputs"] = flow_node_inputs
flow_node_data["outputs"] = flow_node_outputs

flow_node_data

# --------------
# flow_node_dict = {
# "name": pipeline_node.name,
# "template": pipeline_node.template,
# "inputs": pipeline_node.inputs.model_dump(),
# "depends": pipeline_node.depends,
# }

# try:
# workflow_node_dict = workflow_nodes_dict[pipeline_node.name].
# dict()
# workflow_node_dict = copy_non_null_dict(workflow_node_dict)
# except KeyError:
# flow_node_dict["pod_name"] = pipeline_node.name
# flow_node_dict["phase"] = "Not Scheduled"
# flow_node_dict["outputs"] = FlowNodeOutputs(
# **pipeline_node.outputs.model_dump(),
# ).model_dump()
# flow_node_dict["logs"] = None
# else:
# flow_node_dict["id"] = workflow_node_dict["id"]
# flow_node_dict["type"] = workflow_node_dict["type"]
# flow_node_dict["pod_name"] = workflow_node_dict["name"]
# flow_node_dict["phase"] = workflow_node_dict["phase"]
# flow_node_dict["outputs"] = FlowNodeOutputs(
# exit_code=workflow_node_dict.get("exit_code", None),
# **pipeline_node.outputs.model_dump(),
# ).model_dump()
# flow_node_dict["dependants"] = workflow_node_dict.get(
# "children", None
# )
# flow_node_dict["host_node_name"] = workflow_node_dict.get(
# "host_node_name", None
# )

# # inject resolved input/output values where possible
# for argument_io in ("inputs", "outputs"):
# for argument_type in ("parameters", "artifacts"):
# try:
# workflow_node_arguments = workflow_node_dict[
# argument_io
# ][argument_type]
# flow_node_arguments = flow_node_dict[argument_io][
# argument_type
# ]

# if workflow_node_arguments is None:
# continue
# else:
# for i, argument in enumerate(
# workflow_node_arguments
# ):
# if i < len(flow_node_arguments):
# if (
# flow_node_arguments[i]["name"]
# == argument["name"]
# ):
# if argument_type == "parameters":
# flow_node_arguments[i][
# "value"
# ] = argument["value"]
# elif argument_type == "artifacts":
# flow_node_arguments[i]["s3"] = {
# "key": argument["s3"]["key"],
# "bucket": argument["s3"][
# "bucket"
# ],
# }
# elif argument["name"] == "main-logs":
# flow_node_dict["logs"] = argument
# else:
# pass
# except KeyError:
# pass
# finally:
# flow_node = FlowNode(**flow_node_dict)
flow_node = FlowNode.model_validate(flow_node_data)

return flow_node

Expand Down

0 comments on commit b58625f

Please sign in to comment.