diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py index 6920c34e84..df5c74288e 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/__init__.py @@ -10,4 +10,4 @@ MPIJob """ -from .task import MPIJob +from .task import HorovodJob, MPIJob diff --git a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py index 6f207b421d..e1c1be0a03 100644 --- a/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py +++ b/plugins/flytekit-kf-mpi/flytekitplugins/kfmpi/task.py @@ -133,5 +133,62 @@ def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]: return MessageToDict(job.to_flyte_idl()) +@dataclass +class HorovodJob(object): + slots: int + num_launcher_replicas: int = 1 + num_workers: int = 1 + + +class HorovodFunctionTask(MPIFunctionTask): + """ + For more info, check out https://github.com/horovod/horovod + """ + + # Customize your setup here. Please ensure the cmd, path, volume, etc are available in the pod. + ssh_command = "/usr/sbin/sshd -De -f /home/jobuser/.sshd_config" + discovery_script_path = "/etc/mpi/discover_hosts.sh" + + def __init__(self, task_config: HorovodJob, task_function: Callable, **kwargs): + + super().__init__( + task_config=task_config, + task_function=task_function, + **kwargs, + ) + + def get_command(self, settings: SerializationSettings) -> List[str]: + cmd = super().get_command(settings) + mpi_cmd = self._get_horovod_prefix() + cmd + return mpi_cmd + + def get_config(self, settings: SerializationSettings) -> Dict[str, str]: + config = super().get_config(settings) + return {**config, "worker_spec_command": self.ssh_command} + + def _get_horovod_prefix(self) -> List[str]: + np = self.task_config.num_workers * self.task_config.slots + base_cmd = [ + "horovodrun", + "-np", + f"{np}", + "--verbose", + "--log-level", + "INFO", + "--network-interface", + "eth0", + "--min-np", + f"{np}", + "--max-np", + f"{np}", + "--slots-per-host", + f"{self.task_config.slots}", + "--host-discovery-script", + self.discovery_script_path, + ] + return base_cmd + + # Register the MPI Plugin into the flytekit core plugin system TaskPlugins.register_pythontask_plugin(MPIJob, MPIFunctionTask) +TaskPlugins.register_pythontask_plugin(HorovodJob, HorovodFunctionTask) diff --git a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py index ebb0c49b58..7732d520c2 100644 --- a/plugins/flytekit-kf-mpi/tests/test_mpi_task.py +++ b/plugins/flytekit-kf-mpi/tests/test_mpi_task.py @@ -1,4 +1,4 @@ -from flytekitplugins.kfmpi.task import MPIJob, MPIJobModel +from flytekitplugins.kfmpi.task import HorovodJob, MPIJob, MPIJobModel from flytekit import Resources, task from flytekit.configuration import Image, ImageConfig, SerializationSettings @@ -41,3 +41,26 @@ def my_mpi_task(x: int, y: str) -> int: assert my_mpi_task.get_custom(settings) == {"numLauncherReplicas": 10, "numWorkers": 10, "slots": 1} assert my_mpi_task.task_type == "mpi" + + +def test_horovod_task(): + @task( + task_config=HorovodJob(num_workers=5, num_launcher_replicas=5, slots=1), + ) + def my_horovod_task(): + ... + + default_img = Image(name="default", fqn="test", tag="tag") + settings = SerializationSettings( + project="project", + domain="domain", + version="version", + env={"FOO": "baz"}, + image_config=ImageConfig(default_image=default_img, images=[default_img]), + ) + cmd = my_horovod_task.get_command(settings) + assert "horovodrun" in cmd + config = my_horovod_task.get_config(settings) + assert "/usr/sbin/sshd" in config["worker_spec_command"] + custom = my_horovod_task.get_custom(settings) + assert isinstance(custom, dict) is True