Skip to content

Commit

Permalink
Add tests for new dispatch API, make old tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
MattToast committed Jul 20, 2024
1 parent 0b9ec1a commit 6d5a3c4
Show file tree
Hide file tree
Showing 16 changed files with 357 additions and 72 deletions.
3 changes: 2 additions & 1 deletion smartsim/settings/builders/launch/alps.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from ..launchArgBuilder import LaunchArgBuilder

logger = get_logger(__name__)
_format_aprun_command = shell_format(run_command="aprun")


@dispatch(with_format=shell_format(run_command="aprun"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_aprun_command, to_launcher=ShellLauncher)
class AprunArgBuilder(LaunchArgBuilder):
def _reserved_launch_args(self) -> set[str]:
"""Return reserved launch arguments."""
Expand Down
3 changes: 2 additions & 1 deletion smartsim/settings/builders/launch/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from ..launchArgBuilder import LaunchArgBuilder

logger = get_logger(__name__)
_format_local_command = shell_format(run_command=None)


@dispatch(with_format=shell_format(run_command=None), to_launcher=ShellLauncher)
@dispatch(with_format=_format_local_command, to_launcher=ShellLauncher)
class LocalArgBuilder(LaunchArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
Expand Down
3 changes: 2 additions & 1 deletion smartsim/settings/builders/launch/lsf.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from ..launchArgBuilder import LaunchArgBuilder

logger = get_logger(__name__)
_format_jsrun_command = shell_format(run_command="jsrun")


@dispatch(with_format=shell_format(run_command="jsrun"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_jsrun_command, to_launcher=ShellLauncher)
class JsrunArgBuilder(LaunchArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
Expand Down
9 changes: 6 additions & 3 deletions smartsim/settings/builders/launch/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@
from ..launchArgBuilder import LaunchArgBuilder

logger = get_logger(__name__)
_format_mpirun_command = shell_format("mpirun")
_format_mpiexec_command = shell_format("mpiexec")
_format_orterun_command = shell_format("orterun")


class _BaseMPIArgBuilder(LaunchArgBuilder):
Expand Down Expand Up @@ -215,21 +218,21 @@ def set(self, key: str, value: str | None) -> None:
self._launch_args[key] = value


@dispatch(with_format=shell_format(run_command="mpirun"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_mpirun_command, to_launcher=ShellLauncher)
class MpiArgBuilder(_BaseMPIArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
return LauncherType.Mpirun.value


@dispatch(with_format=shell_format(run_command="mpiexec"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_mpiexec_command, to_launcher=ShellLauncher)
class MpiexecArgBuilder(_BaseMPIArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
return LauncherType.Mpiexec.value


@dispatch(with_format=shell_format(run_command="orterun"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_orterun_command, to_launcher=ShellLauncher)
class OrteArgBuilder(_BaseMPIArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
Expand Down
3 changes: 2 additions & 1 deletion smartsim/settings/builders/launch/pals.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,10 @@
from ..launchArgBuilder import LaunchArgBuilder

logger = get_logger(__name__)
_format_mpiexec_command = shell_format(run_command="mpiexec")


@dispatch(with_format=shell_format(run_command="mpiexec"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_mpiexec_command, to_launcher=ShellLauncher)
class PalsMpiexecArgBuilder(LaunchArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
Expand Down
3 changes: 2 additions & 1 deletion smartsim/settings/builders/launch/slurm.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,10 @@
from ..launchArgBuilder import LaunchArgBuilder

logger = get_logger(__name__)
_format_srun_command = shell_format(run_command="srun")


@dispatch(with_format=shell_format(run_command="srun"), to_launcher=ShellLauncher)
@dispatch(with_format=_format_srun_command, to_launcher=ShellLauncher)
class SlurmArgBuilder(LaunchArgBuilder):
def launcher_str(self) -> str:
"""Get the string representation of the launcher"""
Expand Down
9 changes: 2 additions & 7 deletions tests/temp_tests/test_settings/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from unittest.mock import Mock

import pytest

from smartsim.settings import dispatch
Expand All @@ -34,7 +32,7 @@

@pytest.fixture
def echo_executable_like():
class _ExeLike(launch.ExecutableLike):
class _ExeLike(dispatch.ExecutableLike):
def as_program_arguments(self):
return ("echo", "hello", "world")

Expand All @@ -44,13 +42,10 @@ def as_program_arguments(self):
@pytest.fixture
def settings_builder():
class _SettingsBuilder(launch.LaunchArgBuilder):
def set(self, arg, val): ...
def launcher_str(self):
return "Mock Settings Builder"

def set(self, arg, val): ...
def finalize(self, exe, env):
return Mock()

yield _SettingsBuilder({})


Expand Down
7 changes: 5 additions & 2 deletions tests/temp_tests/test_settings/test_alpsLauncher.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest

from smartsim.settings import LaunchSettings
from smartsim.settings.builders.launch.alps import AprunArgBuilder
from smartsim.settings.builders.launch.alps import (
AprunArgBuilder,
_format_aprun_command,
)
from smartsim.settings.launchCommand import LauncherType

pytestmark = pytest.mark.group_a
Expand Down Expand Up @@ -183,5 +186,5 @@ def test_invalid_exclude_hostlist_format():
),
)
def test_formatting_launch_args(echo_executable_like, args, expected):
cmd = AprunArgBuilder(args).finalize(echo_executable_like, {})
cmd = _format_aprun_command(AprunArgBuilder(args), echo_executable_like, {})
assert tuple(cmd) == expected
Loading

0 comments on commit 6d5a3c4

Please sign in to comment.