From f849a0fee49cf4cf496dffc3ca1f0e4706e098ef Mon Sep 17 00:00:00 2001 From: byhsu Date: Sun, 2 Apr 2023 21:28:30 -0700 Subject: [PATCH 1/6] Add horovod task to mpi plugin Signed-off-by: byhsu --- .../flytekitplugins/kfmpi/task.py | 72 ++++++++++++++++++- 1 file changed, 70 insertions(+), 2 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 6f207b421d..7ec8efd4ed 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -12,7 +12,10 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import TaskPlugins from flytekit.models import common as _common - +from flytekit.core.pod_template import PodTemplate, PRIMARY_CONTAINER_DEFAULT_NAME +from kubernetes.client.models import V1PodSpec, V1Volume, V1EmptyDirVolumeSource, V1Container, V1VolumeMount +from flytekit.lnkd.pod_template import LIPodTemplate +from flytekit.tools.translator import get_command_prefix_for_fast_execute class MPIJobModel(_common.FlyteIdlEntity): """Model definition for MPI the plugin @@ -80,7 +83,6 @@ class MPIJob(object): num_launcher_replicas: int = 1 num_workers: int = 1 - class MPIFunctionTask(PythonFunctionTask[MPIJob]): """ Plugin that submits a MPIJob (see https://github.com/kubeflow/mpi-operator) @@ -133,5 +135,71 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) +@dataclass +class HorovodJob(object): + slots: int + num_launcher_replicas: int = 1 + num_workers: int = 1 + +class HorovodFunctionTask(PythonFunctionTask[HorovodJob]): + """ + For more info, check out https://github.com/horovod/horovod + """ + _MPI_JOB_TASK_TYPE = "mpi" + + # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. + ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" + ssh_auth_mount_path = "/home/jobuser/.ssh" + discovery_script_path = "/etc/mpi/discover_hosts.sh" + + def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): + + super().__init__( + task_config=task_config, + task_function=task_function, + task_type=self._MPI_JOB_TASK_TYPE, + **kwargs, + ) + + def get_command(self, settings: SerializationSettings) -> List[str]: + cmd = super().get_command(settings) + mpi_cmd = self._get_horovod_prefix() + cmd + return mpi_cmd + + def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: + job = MPIJobModel( + num_workers=self.task_config.num_workers, + num_launcher_replicas=self.task_config.num_launcher_replicas, + slots=self.task_config.slots, + ) + return MessageToDict(job.to_flyte_idl()) + + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + config = super().get_config(settings) + return {**config, "worker_spec_command": self.ssh_command} + + def _get_horovod_prefix(self) -> List[str]: + np = self.task_config.num_workers * self.task_config.slots + base_cmd = [ + "horovodrun", + "-np", + f"{np}", + "--verbose", + "--log-level", + "INFO", + "--network-interface", + "eth0", + "--min-np", + f"{np}", + "--max-np", + f"{np}", + "--slots-per-host", + f"{self.task_config.slots}", + "--host-discovery-script", + self.discovery_script_path, + ] + return base_cmd + # Register the MPI Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask) +TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask) \ No newline at end of file From 465eef04a825c1e090fa84dccb6f3e62c74bedfa Mon Sep 17 00:00:00 2001 From: byhsu Date: Sun, 2 Apr 2023 21:29:01 -0700 Subject: [PATCH 2/6] Remove unused Signed-off-by: byhsu --- plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 7ec8efd4ed..6befb9d667 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -12,9 +12,6 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import TaskPlugins from flytekit.models import common as _common -from flytekit.core.pod_template import PodTemplate, PRIMARY_CONTAINER_DEFAULT_NAME -from kubernetes.client.models import V1PodSpec, V1Volume, V1EmptyDirVolumeSource, V1Container, V1VolumeMount -from flytekit.lnkd.pod_template import LIPodTemplate from flytekit.tools.translator import get_command_prefix_for_fast_execute class MPIJobModel(_common.FlyteIdlEntity): From c13a713f89614b509529e7e641c84c0283438989 Mon Sep 17 00:00:00 2001 From: byhsu Date: Sun, 2 Apr 2023 22:32:24 -0700 Subject: [PATCH 3/6] add test Signed-off-by: byhsu --- .../flytekitplugins/kfmpi/__init__.py | 1 + .../flytekitplugins/kfmpi/task.py | 2 -- .../flytekit-kf-mpi/tests/test_mpi_task.py | 23 ++++++++++++++++++- 3 files changed, 23 insertions(+), 3 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index 6920c34e84..6bedc5604a 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -11,3 +11,4 @@ """ from .task import MPIJob +from .task import HorovodJob diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 6befb9d667..7665b40d3f 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -12,7 +12,6 @@ from flytekit.configuration import SerializationSettings from flytekit.extend import TaskPlugins from flytekit.models import common as _common -from flytekit.tools.translator import get_command_prefix_for_fast_execute class MPIJobModel(_common.FlyteIdlEntity): """Model definition for MPI the plugin @@ -146,7 +145,6 @@ class HorovodFunctionTask(PythonFunctionTask[HorovodJob]): # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" - ssh_auth_mount_path = "/home/jobuser/.ssh" discovery_script_path = "/etc/mpi/discover_hosts.sh" def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index ebb0c49b58..86aa5a5b9c 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,4 +1,4 @@ -from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel +from flytekitplugins.kfmpi.task import MPIJob, HorovodJob, MPIJobModel from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -41,3 +41,24 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} assert my_mpi_task.task_type == "mpi" + +def test_horovod_task(): + @task( + task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), + ) + def my_horovod_task(): + ... + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + cmd = my_horovod_task.get_command(settings) + assert "horovodrun" in cmd + config = my_horovod_task.get_config(settings) + assert "/usr/sbin/sshd" in config["worker_spec_command"] + From 7f39117f95c74a58fa6c3b7f39ec75a18a8d0ae3 Mon Sep 17 00:00:00 2001 From: byhsu Date: Mon, 3 Apr 2023 22:37:53 -0700 Subject: [PATCH 4/6] inherit from mpifunctiontask Signed-off-by: byhsu --- .../flytekitplugins/kfmpi/task.py | 16 +++------------- plugins/flytekit-kf-mpi/tests/test_mpi_task.py | 3 ++- 2 files changed, 5 insertions(+), 14 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 7665b40d3f..fcf7c13548 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -137,22 +137,20 @@ class HorovodJob(object): num_launcher_replicas: int = 1 num_workers: int = 1 -class HorovodFunctionTask(PythonFunctionTask[HorovodJob]): +class HorovodFunctionTask(MPIFunctionTask): """ For more info, check out https://github.com/horovod/horovod """ - _MPI_JOB_TASK_TYPE = "mpi" # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" discovery_script_path = "/etc/mpi/discover_hosts.sh" - def __init__(self, task_config: MPIJob, task_function: Callable, **kwargs): + def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): super().__init__( task_config=task_config, task_function=task_function, - task_type=self._MPI_JOB_TASK_TYPE, **kwargs, ) @@ -160,14 +158,6 @@ def get_command(self, settings: SerializationSettings) -> List[str]: cmd = super().get_command(settings) mpi_cmd = self._get_horovod_prefix() + cmd return mpi_cmd - - def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: - job = MPIJobModel( - num_workers=self.task_config.num_workers, - num_launcher_replicas=self.task_config.num_launcher_replicas, - slots=self.task_config.slots, - ) - return MessageToDict(job.to_flyte_idl()) def get_config(self, settings: SerializationSettings) -> Dict[str, str]: config = super().get_config(settings) @@ -193,7 +183,7 @@ def _get_horovod_prefix(self) -> List[str]: "--host-discovery-script", self.discovery_script_path, ] - return base_cmd + return base_cmd # Register the MPI Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask) diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 86aa5a5b9c..6d30dd6c92 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -61,4 +61,5 @@ def my_horovod_task(): assert "horovodrun" in cmd config = my_horovod_task.get_config(settings) assert "/usr/sbin/sshd" in config["worker_spec_command"] - + custom = my_horovod_task.get_custom(settings) + assert isinstance(custom, dict) == True From 8126930c655f96b51596761ae598e513a59878ab Mon Sep 17 00:00:00 2001 From: byhsu Date: Wed, 5 Apr 2023 10:14:35 -0700 Subject: [PATCH 5/6] fix format Signed-off-by: byhsu --- .../flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py | 3 +-- plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py | 10 +++++++--- plugins/flytekit-kf-mpi/tests/test_mpi_task.py | 3 ++- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index 6bedc5604a..df5c74288e 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,5 +10,4 @@ MPIJob """ -from .task import MPIJob -from .task import HorovodJob +from .task import HorovodJob, MPIJob diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index fcf7c13548..e1c1be0a03 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -13,6 +13,7 @@ from flytekit.extend import TaskPlugins from flytekit.models import common as _common + class MPIJobModel(_common.FlyteIdlEntity): """Model definition for MPI the plugin @@ -79,6 +80,7 @@ class MPIJob(object): num_launcher_replicas: int = 1 num_workers: int = 1 + class MPIFunctionTask(PythonFunctionTask[MPIJob]): """ Plugin that submits a MPIJob (see https://github.com/kubeflow/mpi-operator) @@ -137,6 +139,7 @@ class HorovodJob(object): num_launcher_replicas: int = 1 num_workers: int = 1 + class HorovodFunctionTask(MPIFunctionTask): """ For more info, check out https://github.com/horovod/horovod @@ -158,11 +161,11 @@ def get_command(self, settings: SerializationSettings) -> List[str]: cmd = super().get_command(settings) mpi_cmd = self._get_horovod_prefix() + cmd return mpi_cmd - + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: config = super().get_config(settings) return {**config, "worker_spec_command": self.ssh_command} - + def _get_horovod_prefix(self) -> List[str]: np = self.task_config.num_workers * self.task_config.slots base_cmd = [ @@ -185,6 +188,7 @@ def _get_horovod_prefix(self) -> List[str]: ] return base_cmd + # Register the MPI Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask) -TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask) \ No newline at end of file +TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask) diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 6d30dd6c92..03cdeb4034 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,4 +1,4 @@ -from flytekitplugins.kfmpi.task import MPIJob, HorovodJob, MPIJobModel +from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -42,6 +42,7 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} assert my_mpi_task.task_type == "mpi" + def test_horovod_task(): @task( task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), From 5f4a285a197cbcaf0110889c89723cc35a70a3c8 Mon Sep 17 00:00:00 2001 From: byhsu Date: Mon, 10 Apr 2023 17:38:26 -0700 Subject: [PATCH 6/6] fix format Signed-off-by: byhsu --- plugins/flytekit-kf-mpi/tests/test_mpi_task.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index 03cdeb4034..7732d520c2 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -63,4 +63,4 @@ def my_horovod_task(): config = my_horovod_task.get_config(settings) assert "/usr/sbin/sshd" in config["worker_spec_command"] custom = my_horovod_task.get_custom(settings) - assert isinstance(custom, dict) == True + assert isinstance(custom, dict) is True