Skip to content

Commit

Permalink
Merge pull request #40 from SebastianScherer88/fix-k8s-service-naming…
Browse files Browse the repository at this point in the history
…-issue-for-ddp-components

Fix k8s service naming issue for ddp components
  • Loading branch information
SebastianScherer88 authored Oct 29, 2024
2 parents 133d32d + 2fd1a5e commit 400baa7
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 32 deletions.
3 changes: 2 additions & 1 deletion sdk/bettmensch_ai/components/base_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class BaseComponent(object):
name: str = None
func: Callable = None
base_name: str = None
name: str = None
name: str = None # this attribute will hold the name of the ArgoWorkflow
# Task and the Template of this component
hera_template_kwargs: Dict = {}
template_inputs: Dict[str, Union[InputParameter, InputArtifact]] = None
template_outputs: Dict[str, Union[OutputParameter, OutputArtifact]] = None
Expand Down
15 changes: 5 additions & 10 deletions sdk/bettmensch_ai/components/torch_ddp_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import inspect
import textwrap
from typing import Callable, Dict, List, Optional, Union
from uuid import uuid4

from bettmensch_ai.components.base_component import BaseComponent
from bettmensch_ai.components.base_inline_script_runner import (
Expand Down Expand Up @@ -87,7 +86,7 @@ def generate_source(self, instance: Script) -> str:
# add function definition and decoration with `torch_ddp`
torch_ddp_decoration = [
"\nfrom bettmensch_ai.components import torch_ddp\n",
"torch_ddp_decorator=torch_ddp()\n"
"torch_ddp_decorator=torch_ddp()\n",
f"""torch_ddp_function=torch_ddp_decorator({
instance.source.__name__
})\n""",
Expand Down Expand Up @@ -120,7 +119,6 @@ class TorchDDPComponent(BaseComponent):
nproc_per_node: int
service_templates: Dict[str, Callable] = None
k8s_namespace: str = ARGO_NAMESPACE
k8s_service_name: str = ""

# if no resources are specified, set minimal requirements derived from
# testing the ddp example on K8s
Expand Down Expand Up @@ -165,11 +163,11 @@ def build_service_templates(self) -> Dict[str, Callable]:
return {
"create": create_torch_ddp_service_template(
component_base_name=self.base_name,
service_name=self.k8s_service_name,
component_task_name=self.name,
),
"delete": delete_torch_ddp_service_template(
component_base_name=self.base_name,
service_name=self.k8s_service_name,
component_task_name=self.name,
),
}

Expand Down Expand Up @@ -232,8 +230,7 @@ def build_script_decorator_kwargs(self, torch_node_rank: int) -> Dict:
),
Env(
name=f"{LaunchConfigSettings.model_config['env_prefix']}rdzv_endpoint_url", # noqa: E501
value=f"{self.k8s_service_name}.{self.k8s_namespace}"
+ ".svc.cluster.local",
value=f"{self.name}-{{{{workflow.uid}}}}.{self.k8s_namespace}.svc.cluster.local", # noqa: E501
),
Env(
name=f"{LaunchConfigSettings.model_config['env_prefix']}rdzv_endpoint_port", # noqa: E501
Expand All @@ -255,7 +252,7 @@ def build_script_decorator_kwargs(self, torch_node_rank: int) -> Dict:
labels.update(
{
"torch-node": torch_node_rank,
"torch-job": self.k8s_service_name,
"torch-job": self.name,
}
)
script_decorator_kwargs["labels"] = labels
Expand All @@ -273,8 +270,6 @@ def build_hera_task_factory(self) -> List[Callable]:
node in the distributed torch run.
"""

self.k8s_service_name = f"{self.name}-{uuid4()}"

# add torch run environment variables to script kwargs
task_factory = []

Expand Down
14 changes: 8 additions & 6 deletions sdk/bettmensch_ai/components/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def wrapper(*function_args):

def create_torch_ddp_service_template(
component_base_name: str,
service_name: str,
component_task_name: str,
namespace: str = ARGO_NAMESPACE, # noqa: E501
) -> Resource:
"""Utility for a template creating the service resource required for
Expand All @@ -152,26 +152,28 @@ def create_torch_ddp_service_template(
manifest=f"""apiVersion: v1
kind: Service
metadata:
name: {service_name}
name: {component_task_name}-{{{{workflow.uid}}}}
namespace: {namespace}
labels:
app: {service_name}
workflows.argoproj.io/workflow: {{{{workflow.name}}}}
torch-job: {component_task_name}
spec:
clusterIP: None # ClusterIP set to None for headless service.
ports:
- name: {DDP_PORT_NAME} # Port for torchrun master<->worker node coms.
port: {DDP_PORT_NUMBER}
targetPort: {DDP_PORT_NUMBER}
selector:
torch-job: {service_name}
workflows.argoproj.io/workflow: {{{{workflow.name}}}}
torch-job: {component_task_name}
torch-node: '0' # Selector for pods associated with this service.
""",
)


def delete_torch_ddp_service_template(
component_base_name: str,
service_name: str,
component_task_name: str,
namespace: str = ARGO_NAMESPACE, # noqa: E501
) -> Resource:
"""Utility for a template deleting the service resource required for
Expand All @@ -183,7 +185,7 @@ def delete_torch_ddp_service_template(
flags=[
"service",
"--selector",
f"app={service_name}",
f"torch-job={component_task_name},workflows.argoproj.io/workflow={{{{workflow.name}}}}", # noqa: E501
"-n",
namespace,
],
Expand Down
8 changes: 2 additions & 6 deletions sdk/test/unit/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ def parameter_to_artifact(
]

assert wft.templates[3].labels["torch-node"] == "0"
assert (
wft.templates[3]
.labels["torch-job"]
.startswith("convert-to-artifact-0-")
)
assert wft.templates[3].labels["torch-job"] == "convert-to-artifact-0"
assert wft.templates[4].labels["torch-node"] == "1"

parameter_to_artifact.export(test_output_dir)
Expand Down Expand Up @@ -146,7 +142,7 @@ def adding_parameters(
]

assert wft.templates[3].labels["torch-node"] == "0"
assert wft.templates[3].labels["torch-job"].startswith("a-plus-b-0-")
assert wft.templates[3].labels["torch-job"] == "a-plus-b-0"
assert wft.templates[4].labels["torch-node"] == "1"

adding_parameters.export(test_output_dir)
24 changes: 15 additions & 9 deletions sdk/test/unit/test_torch_ddp_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,9 @@ def test_function(
assert test_component.task_factory is None


def test_parameter_torch_component_to_hera(test_mock_pipeline):
def test_parameter_torch_component_to_hera(
test_output_dir, test_mock_pipeline
):
"""Declaration of Component using InputParameter and OutputParameter"""

# mock active pipeline with 2 inputs
Expand Down Expand Up @@ -255,7 +257,7 @@ def test_parameter_torch_component_to_hera(test_mock_pipeline):
a_plus_b_plus_2.task_factory = a_plus_b_plus_2.build_hera_task_factory()

with WorkflowTemplate(
name="test-parameter-component-workflow-template",
name="test-parameter-torch-ddp-component-workflow-template",
entrypoint="test_dag",
namespace="argo",
arguments=[
Expand All @@ -273,6 +275,8 @@ def test_parameter_torch_component_to_hera(test_mock_pipeline):
a_plus_b.to_hera()
a_plus_b_plus_2.to_hera()

wft.to_file(test_output_dir)

task_names = [task.name for task in wft.templates[4].tasks]
assert task_names == [
"a-plus-b-create-torch-ddp-service",
Expand All @@ -299,17 +303,18 @@ def test_parameter_torch_component_to_hera(test_mock_pipeline):
]

assert wft.templates[5].labels["torch-node"] == "0"
assert wft.templates[5].labels["torch-job"].startswith("a-plus-b-0-")
assert wft.templates[5].labels["torch-job"] == "a-plus-b-0"
assert wft.templates[6].labels["torch-node"] == "1"

assert wft.templates[7].labels["torch-node"] == "0"
assert (
wft.templates[7].labels["torch-job"].startswith("a-plus-b-plus-2-0-")
wft.templates[7].labels["torch-job"] == "a-plus-b-plus-2-0"
) # noqa: E501
assert wft.templates[8].labels["torch-node"] == "1"


def test_artifact_torch_component_to_hera(
test_output_dir,
test_mock_pipeline,
):

Expand Down Expand Up @@ -346,7 +351,7 @@ def test_artifact_torch_component_to_hera(
show.task_factory = show.build_hera_task_factory()

with WorkflowTemplate(
name="test-artifact-component-workflow-template",
name="test-artifact-torch-ddp-component-workflow-template",
entrypoint="test_dag",
namespace="argo",
arguments=[
Expand All @@ -361,6 +366,8 @@ def test_artifact_torch_component_to_hera(
convert.to_hera()
show.to_hera()

wft.to_file(test_output_dir)

task_names = [task.name for task in wft.templates[4].tasks]
assert task_names == [
"convert-parameters-create-torch-ddp-service",
Expand All @@ -386,11 +393,10 @@ def test_artifact_torch_component_to_hera(

assert wft.templates[5].labels["torch-node"] == "0"
assert (
wft.templates[5]
.labels["torch-job"]
.startswith("convert-parameters-0-") # noqa: E501
wft.templates[5].labels["torch-job"]
== "convert-parameters-0" # noqa: E501
)
assert wft.templates[6].labels["torch-node"] == "1"

assert wft.templates[7].labels["torch-node"] == "0"
assert wft.templates[7].labels["torch-job"].startswith("show-artifacts-0-")
assert wft.templates[7].labels["torch-job"] == "show-artifacts-0"

0 comments on commit 400baa7

Please sign in to comment.