Skip to content

Commit

Permalink
Merge pull request #42 from SebastianScherer88/improve-pipeline-and-f…
Browse files Browse the repository at this point in the history
…low-queries

Improve pipeline and flow queries
  • Loading branch information
SebastianScherer88 authored Nov 5, 2024
2 parents ca1cc9f + ea84d71 commit e2a25a0
Show file tree
Hide file tree
Showing 5 changed files with 302 additions and 35 deletions.
19 changes: 19 additions & 0 deletions sdk/bettmensch_ai/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,22 @@ class COMPONENT_IMPLEMENTATION(Enum):
base: str = "base"
standard: str = "standard"
torch_ddp: str = "torch-ddp"


class FLOW_LABEL(Enum):
"""Label keys used on the argo Workflow resource underlying a Flow.

Each member's value is the full label key string.
"""

# set to the registered name of the originating Pipeline when a Flow is run
pipeline_name: str = "bettmensch.ai/pipeline-name"
# set to the registered id of the originating Pipeline when a Flow is run
pipeline_id: str = "bettmensch.ai/pipeline-id"
# label key used to filter Flows by phase; valid values are in FLOW_PHASE
phase: str = "workflows.argoproj.io/phase"


class FLOW_PHASE(Enum):
"""Valid values for the Flow phase label (key: FLOW_LABEL.phase)."""

pending: str = "Pending"
running: str = "Running"
succeeded: str = "Succeeded"
failed: str = "Failed"
error: str = "Error"
# NOTE(review): empty string - presumably a Flow whose phase has not been
# reported yet; confirm against the argo Workflow phase definitions
unknown: str = ""
126 changes: 116 additions & 10 deletions sdk/bettmensch_ai/pipelines/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Optional
from typing import Dict, List, Optional

from bettmensch_ai.constants import ARGO_NAMESPACE, FLOW_LABEL, FLOW_PHASE
from bettmensch_ai.pipelines.client import hera_client
from hera.workflows import Workflow
from hera.workflows.models import Workflow as WorkflowModel
Expand All @@ -18,8 +19,69 @@ def __init__(self, registered_flow: Workflow):

@property
def registered_name(self) -> str:
"""The unique name of the Flow (= argo Workflow).

Returns:
str: The `name` of the wrapped `registered_flow`.
"""
return self.registered_flow.name

@property
def registered_namespace(self) -> str:
"""The namespace of the Flow (= argo Workflow).

Returns:
str: The `namespace` of the wrapped `registered_flow`.
"""

return self.registered_flow.namespace

@property
def registered_pipeline(self) -> str:
"""The unique name of the registered Pipeline (= argo WorkflowTemplate)
that the Flow originates from.

Returns:
str: The name held by the wrapped Workflow's `workflow_template_ref`.
"""
return self.registered_flow.workflow_template_ref.name

@property
def phase(self) -> str:
"""The current phase of the Flow (= argo Workflow).

Returns:
str: The `phase` read from the wrapped Workflow's `status`
(presumably one of the FLOW_PHASE values - confirm).
"""

return self.registered_flow.status.phase

@property
def started_at(self) -> Optional[str]:
"""The time the Flow started (where applicable).

E.g. "2024-11-05T13:04:19Z"

Returns:
Optional[str]: The `started_at` value read from the wrapped Workflow's
`status`; None if the Flow has not started yet.
"""

return self.registered_flow.status.started_at

@property
def finished_at(self) -> Optional[str]:
"""The time the Flow finished (where applicable).

E.g. "2024-11-05T13:04:19Z"

Returns:
Optional[str]: The `finished_at` value read from the wrapped Workflow's
`status`; None if the Flow has not finished yet.
"""

return self.registered_flow.status.finished_at

@classmethod
def from_workflow(cls, workflow: Workflow) -> "Flow":
"""Class method to initialize a Flow instance from a
Expand All @@ -36,7 +98,7 @@ def from_workflow(cls, workflow: Workflow) -> "Flow":


def get_flow(
registered_name: str, registered_namespace: Optional[str] = None
registered_name: str, registered_namespace: str = ARGO_NAMESPACE
) -> Flow:
"""Returns the specified Flow.
Expand All @@ -50,7 +112,7 @@ def get_flow(
"""

workflow_model: WorkflowModel = hera_client.get_workflow(
namespace=registered_name, name=registered_namespace
namespace=registered_namespace, name=registered_name
)

workflow: Workflow = Workflow.from_dict(workflow_model.dict())
Expand All @@ -61,21 +123,65 @@ def get_flow(


def list_flows(
registered_namespace: Optional[str] = None,
label_selector: Optional[str] = None,
field_selector: Optional[str] = None,
registered_namespace: str = ARGO_NAMESPACE,
registered_pipeline_name: Optional[str] = None,
phase: Optional[str] = None,
labels: Dict = {},
**kwargs,
) -> List[Flow]:
"""Lists all flows.
"""Get all flows that meet the query specifications.
Args:
registered_namespace (Optional[str], optional): The namespace in which
the underlying argo Workflow lives. Defaults to ARGO_NAMESPACE.
registered_pipeline_name (Optional[str], optional): Optional filter to
only consider Flows originating from the specified registered
Pipeline. Defaults to None, i.e. no pipeline-based filtering.
phase (Optional[str], optional): Optional filter to only consider Flows
that are in the specified phase. Defaults to None, i.e. no phase-
based filtering.
labels (Dict, optional): Optional filter to only consider Flows whose
underlying argo Workflow resource contains all of the specified
labels. Defaults to {}, i.e. no label-based filtering.
Returns:
List[Flow]: A list of all Flows that meet the query scope.
List[Flow]: A list of Flows that meet the filtering specifications.
"""

# build label selector
if (not labels) and (phase is None) and (registered_pipeline_name is None):
label_selector = None
else:
all_labels = labels.copy()

# add phase label
if phase is not None:
assert phase in (
FLOW_PHASE.error.value,
FLOW_PHASE.failed.value,
FLOW_PHASE.pending.value,
FLOW_PHASE.running.value,
FLOW_PHASE.succeeded.value,
FLOW_PHASE.unknown.value,
), f"Invalid phase spec: {phase}. Must be one of the constants.FLOW_PHASE levels." # noqa: E501
all_labels.update({FLOW_LABEL.phase.value: phase})

# add pipeline identifier label
if registered_pipeline_name is not None:
all_labels.update(
{
FLOW_LABEL.pipeline_name.value: registered_pipeline_name,
}
)

kv_label_list = list(all_labels.items()) # [('a',1),('b',2)]
label_selector = ",".join(
[f"{k}={v}" for k, v in kv_label_list]
) # "a=1,b=2"

response = hera_client.list_workflows(
namespace=registered_namespace,
label_selector=label_selector,
field_selector=field_selector,
**kwargs,
)

Expand All @@ -95,7 +201,7 @@ def list_flows(


def delete_flow(
registered_name: str, registered_namespace: Optional[str] = None, **kwargs
registered_name: str, registered_namespace: str = ARGO_NAMESPACE, **kwargs
) -> WorkflowDeleteResponseModel:
"""Deletes the specified Flow from the server.
Expand Down
97 changes: 83 additions & 14 deletions sdk/bettmensch_ai/pipelines/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import inspect
from typing import Any, Callable, Dict, List, Optional

from bettmensch_ai.constants import COMPONENT_IMPLEMENTATION, PIPELINE_TYPE
from bettmensch_ai.constants import (
ARGO_NAMESPACE,
COMPONENT_IMPLEMENTATION,
FLOW_LABEL,
PIPELINE_TYPE,
)
from bettmensch_ai.io import InputParameter, Parameter
from bettmensch_ai.pipelines.client import (
ArgoWorkflowsBackendConfiguration,
hera_client,
)
from bettmensch_ai.pipelines.flow import Flow, list_flows
from bettmensch_ai.pipelines.pipeline_context import (
PipelineContext,
_pipeline_context,
Expand Down Expand Up @@ -369,6 +375,10 @@ def run(
workflow_template_ref=pipeline_ref,
namespace=self.registered_namespace,
arguments=workflow_inputs,
labels={
FLOW_LABEL.pipeline_name.value: self.registered_name,
FLOW_LABEL.pipeline_id.value: self.registered_id,
},
)

registered_workflow: WorkflowModel = workflow.create(
Expand Down Expand Up @@ -416,26 +426,65 @@ def from_registry(
**kwargs,
)

def list_flows(
    self,
    phase: Optional[str] = None,
    additional_labels: Optional[Dict] = None,
    **kwargs,
) -> List[Flow]:
    """Lists all Flows that originate from this Pipeline.

    Args:
        phase (Optional[str], optional): Optional filter to only consider
            Flows that are in the specified phase. Defaults to None, i.e.
            no phase-based filtering.
        additional_labels (Optional[Dict], optional): Optional filter to
            only consider Flows whose underlying argo Workflow resource
            contains all of the specified labels. Defaults to None, i.e.
            no extra label-based filtering. The filter on this pipeline's
            registered name is always applied in addition.
        **kwargs: Forwarded unchanged to the module-level `list_flows`
            function.

    Returns:
        List[Flow]: A list of Flows that meet the filtering specifications.

    Raises:
        ValueError: If the pipeline has not been registered yet.
    """

    # validate registration status of pipeline
    if not self.registered:
        # the trailing space keeps the two adjacent literals from fusing
        # into "haveran" when Python concatenates them
        raise ValueError(
            "Pipeline needs to be registered first. Are you sure you have "
            "run `register`?"
        )

    # Inside a method body the bare name `list_flows` resolves to the
    # module-level function, not to this method. A fresh dict is passed so
    # neither a shared default nor the caller's mapping can be mutated
    # downstream.
    return list_flows(
        registered_namespace=self.registered_namespace,
        registered_pipeline_name=self.registered_name,
        phase=phase,
        labels=dict(additional_labels) if additional_labels else {},
        **kwargs,
    )


def get_registered_pipeline(
registered_name: str,
registered_namespace: Optional[str] = None,
registered_namespace: Optional[str] = ARGO_NAMESPACE,
**kwargs,
) -> Pipeline:
"""Returns the specified registered Pipeline.
"""Get the registered pipeline.
Args:
registered_name (str): The `registered_name` of the Pipeline
(equivalent to the `name` of its underlying WorkflowTemplate).
registered_namespace (str): The `registered_namespace` of the Pipeline
(equivalent to the `namespace` of its underlying WorkflowTemplate).
(i.e. the name its underlying WorkflowTemplate).
registered_namespace (Optional[str], optional): The
`registered_namespace` of the Pipeline (i.e. the namespace of its
underlying WorkflowTemplate). Defaults to ARGO_NAMESPACE.
Returns:
Pipeline: The registered Pipeline instance.
Pipeline: A Pipeline object.
"""

workflow_template_model: WorkflowTemplateModel = (
hera_client.get_workflow_template(
namespace=registered_namespace, name=registered_name
namespace=registered_namespace, name=registered_name, **kwargs
)
)

Expand All @@ -449,24 +498,42 @@ def get_registered_pipeline(


def list_registered_pipelines(
registered_namespace: Optional[str] = None,
registered_namespace: str = ARGO_NAMESPACE,
registered_name_pattern: Optional[str] = None,
label_selector: Optional[str] = None,
field_selector: Optional[str] = None,
labels: Dict = {},
**kwargs,
) -> List[Pipeline]:
"""Lists all registered pipelines.
"""Get all registered pipelines that meet the query specification.
Args:
registered_namespace (Optional[str], optional): The namespace in which
the underlying argo WorkflowTemplate lives. Defaults to
ARGO_NAMESPACE.
registered_name_pattern (Optional[str], optional): The pattern to
filter the argo WorkflowTemplates' names against. Defaults to None,
i.e. no name-based filtering.
labels (Dict, optional): Optional filter to only consider Pipelines
whose underlying argo WorkflowTemplate resource contains all of the
specified labels. Defaults to {}, i.e. no label-based filtering.
Returns:
List[Pipeline]: A list of all registered Pipelines that meet the query
scope.
"""

# build label selector
if not labels:
label_selector = None
else:
kv_label_list = list(labels.items()) # [('a',1),('b',2)]
label_selector = ",".join(
[f"{k}={v}" for k, v in kv_label_list]
) # "a=1,b=2"

response = hera_client.list_workflow_templates(
namespace=registered_namespace,
name_pattern=registered_name_pattern,
label_selector=label_selector,
field_selector=field_selector,
**kwargs,
)

Expand All @@ -487,7 +554,9 @@ def list_registered_pipelines(


def delete_registered_pipeline(
registered_name: str, registered_namespace: Optional[str] = None, **kwargs
registered_name: str,
registered_namespace: Optional[str] = ARGO_NAMESPACE,
**kwargs,
) -> WorkflowTemplateDeleteResponseModel:
"""Deletes the specified registered Pipeline from the server.
Expand Down
18 changes: 16 additions & 2 deletions sdk/test/k8s/test_flow.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,25 @@
import pytest
from bettmensch_ai.pipelines import delete_flow, list_flows
from bettmensch_ai.pipelines import Flow, delete_flow, get_flow, list_flows


@pytest.mark.standard
@pytest.mark.ddp
@pytest.mark.delete_flows
@pytest.mark.order(10)
def test_get_standard_flow(test_namespace):
"""Test the get_flow function.

Every Flow returned by list_flows for the test namespace is re-fetched by
its registered name and must come back as a Flow instance.
"""
flows = list_flows(registered_namespace=test_namespace)

for flow in flows:
flow_reloaded = get_flow(
registered_name=flow.registered_name,
registered_namespace=test_namespace,
)
assert isinstance(flow_reloaded, Flow)


@pytest.mark.standard
@pytest.mark.ddp
@pytest.mark.delete_flows
@pytest.mark.order(11)
def test_delete(test_namespace):
"""Test the delete_flow function"""

Expand Down
Loading

0 comments on commit e2a25a0

Please sign in to comment.