diff --git a/sdk/bettmensch_ai/server/flow.py b/sdk/bettmensch_ai/server/flow.py index f6c1600..1a100ed 100644 --- a/sdk/bettmensch_ai/server/flow.py +++ b/sdk/bettmensch_ai/server/flow.py @@ -1,7 +1,7 @@ from __future__ import annotations from datetime import datetime -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple, Type, Union from bettmensch_ai.server.pipeline import ( NodeArtifactInput, @@ -18,7 +18,6 @@ ResourceTemplate, ScriptTemplate, ) -from bettmensch_ai.server.utils import copy_non_null_dict from hera.workflows.models import NodeStatus as NodeStatusModel from hera.workflows.models import Workflow as WorkflowModel from hera.workflows.models import WorkflowSpec as WorkflowSpecModel @@ -191,29 +190,31 @@ def get_node_status_by_display_name( @classmethod def _get_model_classes( + cls, pipeline_io: Union[ - type[PipelineInputs], - type[PipelineOutputs], - type[NodeInputs], - type[NodeOutputs], - ] + Type[PipelineInputs], + Type[PipelineOutputs], + Type[NodeInputs], + Type[NodeOutputs], + ], ) -> Tuple[ + Literal["inputs", "outputs"], Union[ - type[FlowParameterInput], - type[FlowParameterOutput], - type[FlowNodeParameterInput], - type[FlowNodeParameterOutput], + Type[FlowParameterInput], + Type[FlowParameterOutput], + Type[FlowNodeParameterInput], + Type[FlowNodeParameterOutput], ], Union[ - type[FlowArtifactOutput], - type[FlowNodeArtifactInput], - type[FlowNodeArtifactOutput], + Type[FlowArtifactOutput], + Type[FlowNodeArtifactInput], + Type[FlowNodeArtifactOutput], ], Union[ - type[FlowInputs], - type[FlowOutputs], - type[FlowNodeInputs], - type[FlowNodeOutputs], + Type[FlowInputs], + Type[FlowOutputs], + Type[FlowNodeInputs], + Type[FlowNodeOutputs], ], ]: """_summary_ @@ -249,17 +250,24 @@ def _get_model_classes( """ if pipeline_io == PipelineInputs: - return FlowParameterInput, None, FlowInputs + return "inputs", FlowParameterInput, None, FlowInputs elif pipeline_io == PipelineOutputs: - return FlowParameterOutput, FlowArtifactOutput, FlowOutputs + return ( + "outputs", + FlowParameterOutput, + FlowArtifactOutput, + FlowOutputs, + ) elif pipeline_io == NodeInputs: return ( + "inputs", FlowNodeParameterInput, FlowNodeArtifactInput, FlowNodeInputs, ) elif pipeline_io == NodeOutputs: return ( + "outputs", FlowNodeParameterOutput, FlowNodeArtifactOutput, FlowNodeOutputs, @@ -276,7 +284,6 @@ def _get_model_classes( @classmethod def _build_generic_io( cls, - io_type: Literal["inputs", "outputs"], pipeline_io: Union[ PipelineInputs, PipelineOutputs, @@ -286,26 +293,27 @@ def _build_generic_io( flow_node: NodeStatusModel, ) -> Union[FlowInputs, FlowOutputs, FlowNodeInputs, FlowNodeOutputs]: """Parametrizable utility function to build either inputs or outputs - for the Flow or one of its FlowNodes. + for the Flow or one of its FlowNodes using Pipeline or Node I/Os. Args: - io_type (Literal["inputs", "outputs"]): _description_ pipeline_io (Union[ PipelineInputs, PipelineOutputs, NodeInputs, NodeOutputs, ]): _description_ - flow_node (NodeStatusModel): _description_ + flow_node (NodeStatusModel): A pipeline or pipeline node I/O + instance Returns: Union[FlowInputs, FlowOutputs, FlowNodeInputs, FlowNodeOutputs]: - _description_ + A flow or flow node I/O instance """ io = {"parameters": [], "artifacts": []} ( + io_type, parameter_io_class, artifact_io_class, io_class, @@ -313,7 +321,7 @@ def _build_generic_io( # parameters for ppo in pipeline_io.parameters: - flow_parameter_output_data = ppo.model_dump() + flow_parameter_io_data = ppo.model_dump() try: fpo_value = [ fpo @@ -323,29 +331,32 @@ def _build_generic_io( except (AttributeError, IndexError): fpo_value = None finally: - flow_parameter_output_data["value"] = fpo_value + flow_parameter_io_data["value"] = fpo_value flow_parameter_output = parameter_io_class.model_validate( - flow_parameter_output_data + flow_parameter_io_data ) io["parameters"].append(flow_parameter_output) # artifacts - for pao in pipeline_io.artifacts: - flow_artifact_output_data = pao.model_dump() - try: - fao_s3 = [ - fao - for fao in getattr(flow_node, io_type).artifacts - if fao.name == pao.name - ][0].s3.dict() - except (AttributeError, IndexError): - fao_s3 = None - finally: - flow_artifact_output_data["s3"] = fao_s3 - flow_parameter_output = artifact_io_class.model_validate( - flow_artifact_output_data - ) - io["artifacts"].append(flow_parameter_output) + if artifact_io_class is not None: + for pao in pipeline_io.artifacts: + flow_artifact_io_data = pao.model_dump() + try: + fao_s3 = [ + fao + for fao in getattr(flow_node, io_type).artifacts + if fao.name == pao.name + ][0].s3.dict() + except (AttributeError, IndexError): + fao_s3 = None + finally: + flow_artifact_io_data["s3"] = fao_s3 + flow_parameter_output = artifact_io_class.model_validate( + flow_artifact_io_data + ) + io["artifacts"].append(flow_parameter_output) + else: + del io["artifacts"] return io_class.model_validate(io) @@ -373,81 +384,13 @@ def build_io( ) # --- inputs - flow_inputs = { - "parameters": [], - } - - # parameters - for ppi in pipeline_inputs.parameters: - flow_parameter_input_data = ppi.model_dump() - try: - fpi_value = [ - fpi - for fpi in inner_dag_node.inputs.parameters - if fpi.name == ppi.name - ][0].value - except (AttributeError, IndexError): - fpi_value = None - finally: - flow_parameter_input_data["value"] = fpi_value - flow_parameter_input = FlowParameterInput.model_validate( - flow_parameter_input_data - ) - flow_inputs["parameters"].append(flow_parameter_input) + flow_inputs = cls._build_generic_io(pipeline_inputs, inner_dag_node) # --- outputs - flow_outputs = {"parameters": [], "artifacts": []} - - # parameters - for ppo in pipeline_outputs.parameters: - flow_parameter_output_data = ppo.model_dump() - try: - fpo_value = [ - fpo - for fpo in inner_dag_node.outputs.parameters - if fpo.name == ppo.name - ][0].value - except (AttributeError, IndexError): - fpo_value = None - finally: - flow_parameter_output_data["value"] = fpo_value - flow_parameter_output = FlowParameterOutput.model_validate( - flow_parameter_output_data - ) - flow_outputs["parameters"].append(flow_parameter_output) - - # artifacts - for pao in pipeline_outputs.artifacts: - flow_artifact_output_data = pao.model_dump() - try: - fao_s3 = [ - fao - for fao in inner_dag_node.outputs.artifacts - if fao.name == pao.name - ][0].s3.dict() - except (AttributeError, IndexError): - fao_s3 = None - finally: - flow_artifact_output_data["s3"] = fao_s3 - flow_parameter_output = FlowArtifactOutput.model_validate( - flow_artifact_output_data - ) - flow_outputs["artifacts"].append(flow_parameter_output) - - flow_outputs = FlowOutputs.model_validate(flow_outputs) + flow_outputs = cls._build_generic_io(pipeline_outputs, inner_dag_node) return flow_inputs, flow_outputs - @classmethod - def build_flow_node_io( - cls, - pipeline_node_io: Union[NodeArtifactInput, NodeArtifactOutput], - flow_node: NodeStatusModel, - ) -> Tuple[FlowNodeInputs, FlowNodeOutputs]: - - if flow_node is not None: - flow_node = copy_non_null_dict(flow_node.dict()) - @classmethod def build_flow_node( cls, @@ -470,93 +413,35 @@ def build_flow_node( "depends": pipeline_node.depends, } - try: - flow_node = workflow_nodes_dict[pipeline_node.name] - except KeyError: - flow_node = None + flow_node = workflow_nodes_dict.get(pipeline_node.name, None) - flow_node_inputs, flow_node_outputs = cls.build_flow_node_io( - pipeline_node, flow_node - ) + if flow_node is None: + flow_node_data["pod_name"] = pipeline_node.name + flow_node_data["phase"] = "Not Scheduled" + flow_node_inputs, flow_node_outputs = ( + pipeline_node.inputs.model_dump(), + pipeline_node.outputs.model_dump(), + ) + else: + flow_node_data["id"] = flow_node.id + flow_node_data["type"] = flow_node.type + flow_node_data["pod_name"] = flow_node.name + flow_node_data["phase"] = flow_node.phase + flow_node_data["dependants"] = getattr(flow_node, "children", None) + flow_node_data["host_node_name"] = getattr( + flow_node, "host_node_name", None + ) + flow_node_inputs = cls._build_generic_io( + pipeline_node.inputs, flow_node + ) + flow_node_outputs = cls._build_generic_io( + pipeline_node.outputs, flow_node + ) + + flow_node_data["inputs"] = flow_node_inputs + flow_node_data["outputs"] = flow_node_outputs - flow_node_data - - # -------------- - # flow_node_dict = { - # "name": pipeline_node.name, - # "template": pipeline_node.template, - # "inputs": pipeline_node.inputs.model_dump(), - # "depends": pipeline_node.depends, - # } - - # try: - # workflow_node_dict = workflow_nodes_dict[pipeline_node.name]. - # dict() - # workflow_node_dict = copy_non_null_dict(workflow_node_dict) - # except KeyError: - # flow_node_dict["pod_name"] = pipeline_node.name - # flow_node_dict["phase"] = "Not Scheduled" - # flow_node_dict["outputs"] = FlowNodeOutputs( - # **pipeline_node.outputs.model_dump(), - # ).model_dump() - # flow_node_dict["logs"] = None - # else: - # flow_node_dict["id"] = workflow_node_dict["id"] - # flow_node_dict["type"] = workflow_node_dict["type"] - # flow_node_dict["pod_name"] = workflow_node_dict["name"] - # flow_node_dict["phase"] = workflow_node_dict["phase"] - # flow_node_dict["outputs"] = FlowNodeOutputs( - # exit_code=workflow_node_dict.get("exit_code", None), - # **pipeline_node.outputs.model_dump(), - # ).model_dump() - # flow_node_dict["dependants"] = workflow_node_dict.get( - # "children", None - # ) - # flow_node_dict["host_node_name"] = workflow_node_dict.get( - # "host_node_name", None - # ) - - # # inject resolved input/output values where possible - # for argument_io in ("inputs", "outputs"): - # for argument_type in ("parameters", "artifacts"): - # try: - # workflow_node_arguments = workflow_node_dict[ - # argument_io - # ][argument_type] - # flow_node_arguments = flow_node_dict[argument_io][ - # argument_type - # ] - - # if workflow_node_arguments is None: - # continue - # else: - # for i, argument in enumerate( - # workflow_node_arguments - # ): - # if i < len(flow_node_arguments): - # if ( - # flow_node_arguments[i]["name"] - # == argument["name"] - # ): - # if argument_type == "parameters": - # flow_node_arguments[i][ - # "value" - # ] = argument["value"] - # elif argument_type == "artifacts": - # flow_node_arguments[i]["s3"] = { - # "key": argument["s3"]["key"], - # "bucket": argument["s3"][ - # "bucket" - # ], - # } - # elif argument["name"] == "main-logs": - # flow_node_dict["logs"] = argument - # else: - # pass - # except KeyError: - # pass - # finally: - # flow_node = FlowNode(**flow_node_dict) + flow_node = FlowNode.model_validate(flow_node_data) return flow_node