diff --git a/smartsim/settings/builders/launch/alps.py b/smartsim/settings/builders/launch/alps.py index f325777e4..09d5931ac 100644 --- a/smartsim/settings/builders/launch/alps.py +++ b/smartsim/settings/builders/launch/alps.py @@ -36,9 +36,10 @@ from ..launchArgBuilder import LaunchArgBuilder logger = get_logger(__name__) +_format_aprun_command = shell_format(run_command="aprun") -@dispatch(with_format=shell_format(run_command="aprun"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_aprun_command, to_launcher=ShellLauncher) class AprunArgBuilder(LaunchArgBuilder): def _reserved_launch_args(self) -> set[str]: """Return reserved launch arguments.""" diff --git a/smartsim/settings/builders/launch/local.py b/smartsim/settings/builders/launch/local.py index 21fb71c8a..7002a6831 100644 --- a/smartsim/settings/builders/launch/local.py +++ b/smartsim/settings/builders/launch/local.py @@ -36,9 +36,10 @@ from ..launchArgBuilder import LaunchArgBuilder logger = get_logger(__name__) +_format_local_command = shell_format(run_command=None) -@dispatch(with_format=shell_format(run_command=None), to_launcher=ShellLauncher) +@dispatch(with_format=_format_local_command, to_launcher=ShellLauncher) class LocalArgBuilder(LaunchArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" diff --git a/smartsim/settings/builders/launch/lsf.py b/smartsim/settings/builders/launch/lsf.py index 13a32fd73..ec99d51b9 100644 --- a/smartsim/settings/builders/launch/lsf.py +++ b/smartsim/settings/builders/launch/lsf.py @@ -36,9 +36,10 @@ from ..launchArgBuilder import LaunchArgBuilder logger = get_logger(__name__) +_format_jsrun_command = shell_format(run_command="jsrun") -@dispatch(with_format=shell_format(run_command="jsrun"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_jsrun_command, to_launcher=ShellLauncher) class JsrunArgBuilder(LaunchArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" diff --git a/smartsim/settings/builders/launch/mpi.py b/smartsim/settings/builders/launch/mpi.py index ea24564da..139096010 100644 --- a/smartsim/settings/builders/launch/mpi.py +++ b/smartsim/settings/builders/launch/mpi.py @@ -36,6 +36,9 @@ from ..launchArgBuilder import LaunchArgBuilder logger = get_logger(__name__) +_format_mpirun_command = shell_format("mpirun") +_format_mpiexec_command = shell_format("mpiexec") +_format_orterun_command = shell_format("orterun") class _BaseMPIArgBuilder(LaunchArgBuilder): @@ -215,21 +218,21 @@ def set(self, key: str, value: str | None) -> None: self._launch_args[key] = value -@dispatch(with_format=shell_format(run_command="mpirun"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_mpirun_command, to_launcher=ShellLauncher) class MpiArgBuilder(_BaseMPIArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" return LauncherType.Mpirun.value -@dispatch(with_format=shell_format(run_command="mpiexec"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_mpiexec_command, to_launcher=ShellLauncher) class MpiexecArgBuilder(_BaseMPIArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" return LauncherType.Mpiexec.value -@dispatch(with_format=shell_format(run_command="orterun"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_orterun_command, to_launcher=ShellLauncher) class OrteArgBuilder(_BaseMPIArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" diff --git a/smartsim/settings/builders/launch/pals.py b/smartsim/settings/builders/launch/pals.py index 4f2155c1f..1e7ed814e 100644 --- a/smartsim/settings/builders/launch/pals.py +++ b/smartsim/settings/builders/launch/pals.py @@ -36,9 +36,10 @@ from ..launchArgBuilder import LaunchArgBuilder logger = get_logger(__name__) +_format_mpiexec_command = shell_format(run_command="mpiexec") -@dispatch(with_format=shell_format(run_command="mpiexec"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_mpiexec_command, to_launcher=ShellLauncher) class PalsMpiexecArgBuilder(LaunchArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" diff --git a/smartsim/settings/builders/launch/slurm.py b/smartsim/settings/builders/launch/slurm.py index 907a6da6c..72058f983 100644 --- a/smartsim/settings/builders/launch/slurm.py +++ b/smartsim/settings/builders/launch/slurm.py @@ -38,9 +38,10 @@ from ..launchArgBuilder import LaunchArgBuilder logger = get_logger(__name__) +_format_srun_command = shell_format(run_command="srun") -@dispatch(with_format=shell_format(run_command="srun"), to_launcher=ShellLauncher) +@dispatch(with_format=_format_srun_command, to_launcher=ShellLauncher) class SlurmArgBuilder(LaunchArgBuilder): def launcher_str(self) -> str: """Get the string representation of the launcher""" diff --git a/tests/temp_tests/test_settings/conftest.py b/tests/temp_tests/test_settings/conftest.py index ebf361e97..72061264f 100644 --- a/tests/temp_tests/test_settings/conftest.py +++ b/tests/temp_tests/test_settings/conftest.py @@ -24,8 +24,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -from unittest.mock import Mock - import pytest from smartsim.settings import dispatch @@ -34,7 +32,7 @@ @pytest.fixture def echo_executable_like(): - class _ExeLike(launch.ExecutableLike): + class _ExeLike(dispatch.ExecutableLike): def as_program_arguments(self): return ("echo", "hello", "world") @@ -44,13 +42,10 @@ def as_program_arguments(self): @pytest.fixture def settings_builder(): class _SettingsBuilder(launch.LaunchArgBuilder): + def set(self, arg, val): ... def launcher_str(self): return "Mock Settings Builder" - def set(self, arg, val): ... - def finalize(self, exe, env): - return Mock() - yield _SettingsBuilder({}) diff --git a/tests/temp_tests/test_settings/test_alpsLauncher.py b/tests/temp_tests/test_settings/test_alpsLauncher.py index 7fa95cb6d..5ac2f8e11 100644 --- a/tests/temp_tests/test_settings/test_alpsLauncher.py +++ b/tests/temp_tests/test_settings/test_alpsLauncher.py @@ -1,7 +1,10 @@ import pytest from smartsim.settings import LaunchSettings -from smartsim.settings.builders.launch.alps import AprunArgBuilder +from smartsim.settings.builders.launch.alps import ( + AprunArgBuilder, + _format_aprun_command, +) from smartsim.settings.launchCommand import LauncherType pytestmark = pytest.mark.group_a @@ -183,5 +186,5 @@ def test_invalid_exclude_hostlist_format(): ), ) def test_formatting_launch_args(echo_executable_like, args, expected): - cmd = AprunArgBuilder(args).finalize(echo_executable_like, {}) + cmd = _format_aprun_command(AprunArgBuilder(args), echo_executable_like, {}) assert tuple(cmd) == expected diff --git a/tests/temp_tests/test_settings/test_dispatch.py b/tests/temp_tests/test_settings/test_dispatch.py index ccd1e81cd..78c44ad54 100644 --- a/tests/temp_tests/test_settings/test_dispatch.py +++ b/tests/temp_tests/test_settings/test_dispatch.py @@ -24,31 +24,56 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import abc import contextlib +import dataclasses +import io import pytest +from smartsim.error import errors from smartsim.settings import dispatch pytestmark = pytest.mark.group_a +FORMATTED = object() -def test_declaritive_form_dispatch_declaration(launcher_like, settings_builder): + +def format_fn(args, exe, env): + return FORMATTED + + +@pytest.fixture +def expected_dispatch_registry(launcher_like, settings_builder): + yield { + type(settings_builder): dispatch._DispatchRegistration( + format_fn, type(launcher_like) + ) + } + + +def test_declaritive_form_dispatch_declaration( + launcher_like, settings_builder, expected_dispatch_registry +): d = dispatch.Dispatcher() - assert type(settings_builder) == d.dispatch(to_launcher=type(launcher_like))( - type(settings_builder) - ) - assert d._dispatch_registry == {type(settings_builder): type(launcher_like)} + assert type(settings_builder) == d.dispatch( + with_format=format_fn, to_launcher=type(launcher_like) + )(type(settings_builder)) + assert d._dispatch_registry == expected_dispatch_registry -def test_imperative_form_dispatch_declaration(launcher_like, settings_builder): +def test_imperative_form_dispatch_declaration( + launcher_like, settings_builder, expected_dispatch_registry +): d = dispatch.Dispatcher() - assert None == d.dispatch(type(settings_builder), to_launcher=type(launcher_like)) - assert d._dispatch_registry == {type(settings_builder): type(launcher_like)} + assert None == d.dispatch( + type(settings_builder), to_launcher=type(launcher_like), with_format=format_fn + ) + assert d._dispatch_registry == expected_dispatch_registry def test_dispatchers_from_same_registry_do_not_cross_polute( - launcher_like, settings_builder + launcher_like, settings_builder, expected_dispatch_registry ): some_starting_registry = {} d1 = dispatch.Dispatcher(dispatch_registry=some_starting_registry) @@ -60,12 +85,16 @@ def test_dispatchers_from_same_registry_do_not_cross_polute( d1._dispatch_registry is not d2._dispatch_registry is not some_starting_registry ) - d2.dispatch(type(settings_builder), to_launcher=type(launcher_like)) + d2.dispatch( + type(settings_builder), with_format=format_fn, to_launcher=type(launcher_like) + ) assert d1._dispatch_registry == {} - assert d2._dispatch_registry == {type(settings_builder): type(launcher_like)} + assert d2._dispatch_registry == expected_dispatch_registry -def test_copied_dispatchers_do_not_cross_pollute(launcher_like, settings_builder): +def test_copied_dispatchers_do_not_cross_pollute( + launcher_like, settings_builder, expected_dispatch_registry +): some_starting_registry = {} d1 = dispatch.Dispatcher(dispatch_registry=some_starting_registry) d2 = d1.copy() @@ -76,70 +105,304 @@ def test_copied_dispatchers_do_not_cross_pollute(launcher_like, settings_builder d1._dispatch_registry is not d2._dispatch_registry is not some_starting_registry ) - d2.dispatch(type(settings_builder), to_launcher=type(launcher_like)) + d2.dispatch( + type(settings_builder), to_launcher=type(launcher_like), with_format=format_fn + ) assert d1._dispatch_registry == {} - assert d2._dispatch_registry == {type(settings_builder): type(launcher_like)} + assert d2._dispatch_registry == expected_dispatch_registry @pytest.mark.parametrize( "add_dispatch, expected_ctx", ( pytest.param( - lambda d, s, l: d.dispatch(s, to_launcher=l), + lambda d, s, l: d.dispatch(s, to_launcher=l, with_format=format_fn), pytest.raises(TypeError, match="has already been registered"), id="Imperative -- Disallowed implicitly", ), pytest.param( - lambda d, s, l: d.dispatch(s, to_launcher=l, allow_overwrite=True), + lambda d, s, l: d.dispatch( + s, to_launcher=l, with_format=format_fn, allow_overwrite=True + ), contextlib.nullcontext(), id="Imperative -- Allowed with flag", ), pytest.param( - lambda d, s, l: d.dispatch(to_launcher=l)(s), + lambda d, s, l: d.dispatch(to_launcher=l, with_format=format_fn)(s), pytest.raises(TypeError, match="has already been registered"), id="Declarative -- Disallowed implicitly", ), pytest.param( - lambda d, s, l: d.dispatch(to_launcher=l, allow_overwrite=True)(s), + lambda d, s, l: d.dispatch( + to_launcher=l, with_format=format_fn, allow_overwrite=True + )(s), contextlib.nullcontext(), id="Declarative -- Allowed with flag", ), ), ) def test_dispatch_overwriting( - add_dispatch, expected_ctx, launcher_like, settings_builder + add_dispatch, + expected_ctx, + launcher_like, + settings_builder, + expected_dispatch_registry, ): - registry = {type(settings_builder): type(launcher_like)} - d = dispatch.Dispatcher(dispatch_registry=registry) + d = dispatch.Dispatcher(dispatch_registry=expected_dispatch_registry) with expected_ctx: add_dispatch(d, type(settings_builder), type(launcher_like)) @pytest.mark.parametrize( - "map_settings", + "type_or_instance", ( - pytest.param(type, id="From settings type"), - pytest.param(lambda s: s, id="From settings instance"), + pytest.param(type, id="type"), + pytest.param(lambda x: x, id="instance"), ), ) -def test_dispatch_can_retrieve_launcher_to_dispatch_to( - map_settings, launcher_like, settings_builder +def test_dispatch_can_retrieve_dispatch_info_from_dispatch_registry( + expected_dispatch_registry, launcher_like, settings_builder, type_or_instance ): - registry = {type(settings_builder): type(launcher_like)} - d = dispatch.Dispatcher(dispatch_registry=registry) - assert type(launcher_like) == d.get_launcher_for(map_settings(settings_builder)) + d = dispatch.Dispatcher(dispatch_registry=expected_dispatch_registry) + assert dispatch._DispatchRegistration( + format_fn, type(launcher_like) + ) == d.get_dispatch(type_or_instance(settings_builder)) @pytest.mark.parametrize( - "map_settings", + "type_or_instance", ( - pytest.param(type, id="From settings type"), - pytest.param(lambda s: s, id="From settings instance"), + pytest.param(type, id="type"), + pytest.param(lambda x: x, id="instance"), ), ) def test_dispatch_raises_if_settings_type_not_registered( - map_settings, launcher_like, settings_builder + settings_builder, type_or_instance ): d = dispatch.Dispatcher(dispatch_registry={}) - with pytest.raises(TypeError, match="no launcher type to dispatch to"): - d.get_launcher_for(map_settings(settings_builder)) + with pytest.raises( + TypeError, match="No dispatch for `.+?(?=`)` has been registered" + ): + d.get_dispatch(type_or_instance(settings_builder)) + + +class LauncherABC(abc.ABC): + @abc.abstractmethod + def start(self, launchable): ... + @classmethod + @abc.abstractmethod + def create(cls, exp): ... + + +class PartImplLauncherABC(LauncherABC): + def start(self, launchable): + return dispatch.create_job_id() + + +class FullImplLauncherABC(PartImplLauncherABC): + @classmethod + def create(cls, exp): + return cls() + + +@pytest.mark.parametrize( + "cls, ctx", + ( + pytest.param( + dispatch.LauncherLike, + pytest.raises(TypeError, match="Cannot dispatch to protocol"), + id="Cannot dispatch to protocol class", + ), + pytest.param( + "launcher_like", + contextlib.nullcontext(None), + id="Can dispatch to protocol implementation", + ), + pytest.param( + LauncherABC, + pytest.raises(TypeError, match="Cannot dispatch to abstract class"), + id="Cannot dispatch to abstract class", + ), + pytest.param( + PartImplLauncherABC, + pytest.raises(TypeError, match="Cannot dispatch to abstract class"), + id="Cannot dispatch to partially implemented abstract class", + ), + pytest.param( + FullImplLauncherABC, + contextlib.nullcontext(None), + id="Can dispatch to fully implemented abstract class", + ), + ), +) +def test_register_dispatch_to_launcher_types(request, cls, ctx): + if isinstance(cls, str): + cls = request.getfixturevalue(cls) + d = dispatch.Dispatcher() + with ctx: + d.dispatch(to_launcher=cls, with_format=format_fn) + + +@dataclasses.dataclass +class BufferWriterLauncher(dispatch.LauncherLike[list[str]]): + buf: io.StringIO + + @classmethod + def create(cls, exp): + return cls(io.StringIO()) + + def start(self, strs): + self.buf.writelines(f"{s}\n" for s in strs) + return dispatch.create_job_id() + + +class BufferWriterLauncherSubclass(BufferWriterLauncher): ... + + +@pytest.fixture +def buffer_writer_dispatch(): + stub_format_fn = lambda *a, **kw: ["some", "strings"] + return dispatch._DispatchRegistration(stub_format_fn, BufferWriterLauncher) + + +@pytest.mark.parametrize( + "input_, map_, expected", + ( + pytest.param( + ["list", "of", "strings"], + lambda xs: xs, + ["list\n", "of\n", "strings\n"], + id="[str] -> [str]", + ), + pytest.param( + "words on new lines", + lambda x: x.split(), + ["words\n", "on\n", "new\n", "lines\n"], + id="str -> [str]", + ), + pytest.param( + range(1, 4), + lambda xs: [str(x) for x in xs], + ["1\n", "2\n", "3\n"], + id="[int] -> [str]", + ), + ), +) +def test_launcher_adapter_correctly_adapts_input_to_launcher(input_, map_, expected): + buf = io.StringIO() + adapter = dispatch._LauncherAdapter(BufferWriterLauncher(buf), map_) + adapter.start(input_) + buf.seek(0) + assert buf.readlines() == expected + + +@pytest.mark.parametrize( + "launcher_instance, ctx", + ( + pytest.param( + BufferWriterLauncher(io.StringIO()), + contextlib.nullcontext(None), + id="Correctly configures expected launcher", + ), + pytest.param( + BufferWriterLauncherSubclass(io.StringIO()), + pytest.raises( + TypeError, + match="^Cannot create launcher adapter.*expected launcher of type .+$", + ), + id="Errors if launcher types are disparate", + ), + pytest.param( + "launcher_like", + pytest.raises( + TypeError, + match="^Cannot create launcher adapter.*expected launcher of type .+$", + ), + id="Errors if types are not an exact match", + ), + ), +) +def test_dispatch_registration_can_configure_adapter_for_existing_launcher_instance( + request, settings_builder, buffer_writer_dispatch, launcher_instance, ctx +): + if isinstance(launcher_instance, str): + launcher_instance = request.getfixturevalue(launcher_instance) + with ctx: + adapter = buffer_writer_dispatch.create_adapter_from_launcher( + launcher_instance, settings_builder + ) + assert adapter._adapted_launcher is launcher_instance + + +@pytest.mark.parametrize( + "launcher_instances, ctx", + ( + pytest.param( + (BufferWriterLauncher(io.StringIO()),), + contextlib.nullcontext(None), + id="Correctly configures expected launcher", + ), + pytest.param( + ( + "launcher_like", + "launcher_like", + BufferWriterLauncher(io.StringIO()), + "launcher_like", + ), + contextlib.nullcontext(None), + id="Correctly ignores incompatible launchers instances", + ), + pytest.param( + (), + pytest.raises( + errors.LauncherNotFoundError, + match="^No launcher of exactly type.+could be found from provided launchers$", + ), + id="Errors if no launcher could be found", + ), + pytest.param( + ( + "launcher_like", + BufferWriterLauncherSubclass(io.StringIO), + "launcher_like", + ), + pytest.raises( + errors.LauncherNotFoundError, + match="^No launcher of exactly type.+could be found from provided launchers$", + ), + id="Errors if no launcher matches expected type exactly", + ), + ), +) +def test_dispatch_registration_configures_first_compatible_launcher_from_sequence_of_launchers( + request, settings_builder, buffer_writer_dispatch, launcher_instances, ctx +): + def resolve_instance(inst): + return request.getfixturevalue(inst) if isinstance(inst, str) else inst + + launcher_instances = tuple(map(resolve_instance, launcher_instances)) + + with ctx: + adapter = buffer_writer_dispatch.configure_first_compatible_launcher( + with_settings=settings_builder, from_available_launchers=launcher_instances + ) + + +def test_dispatch_registration_can_create_a_laucher_for_an_experiment_and_can_reconfigure_it_later( + settings_builder, buffer_writer_dispatch +): + class MockExperiment: ... + + exp = MockExperiment() + adapter_1 = buffer_writer_dispatch.create_new_launcher_configuration( + for_experiment=exp, with_settings=settings_builder + ) + assert type(adapter_1._adapted_launcher) == buffer_writer_dispatch.launcher_type + existing_launcher = adapter_1._adapted_launcher + + adapter_2 = buffer_writer_dispatch.create_adapter_from_launcher( + existing_launcher, settings_builder + ) + assert type(adapter_2._adapted_launcher) == buffer_writer_dispatch.launcher_type + assert adapter_1._adapted_launcher is adapter_2._adapted_launcher + assert adapter_1 is not adapter_2 diff --git a/tests/temp_tests/test_settings/test_dragonLauncher.py b/tests/temp_tests/test_settings/test_dragonLauncher.py index 004090eef..57ae67d68 100644 --- a/tests/temp_tests/test_settings/test_dragonLauncher.py +++ b/tests/temp_tests/test_settings/test_dragonLauncher.py @@ -1,5 +1,6 @@ import pytest +from smartsim._core.launcher.dragon.dragonLauncher import _as_run_request_view from smartsim._core.schemas.dragonRequests import DragonRunRequest from smartsim.settings import LaunchSettings from smartsim.settings.builders.launch.dragon import DragonArgBuilder @@ -38,12 +39,12 @@ def test_dragon_class_methods(function, value, flag, result): def test_formatting_launch_args_into_request( echo_executable_like, nodes, tasks_per_node ): - builder = DragonArgBuilder({}) + args = DragonArgBuilder({}) if nodes is not NOT_SET: - builder.set_nodes(nodes) + args.set_nodes(nodes) if tasks_per_node is not NOT_SET: - builder.set_tasks_per_node(tasks_per_node) - req = builder.finalize(echo_executable_like, {}) + args.set_tasks_per_node(tasks_per_node) + req = _as_run_request_view(args, echo_executable_like, {}) args = dict( (k, v) diff --git a/tests/temp_tests/test_settings/test_localLauncher.py b/tests/temp_tests/test_settings/test_localLauncher.py index 4eb314a8b..d69657f23 100644 --- a/tests/temp_tests/test_settings/test_localLauncher.py +++ b/tests/temp_tests/test_settings/test_localLauncher.py @@ -1,7 +1,10 @@ import pytest from smartsim.settings import LaunchSettings -from smartsim.settings.builders.launch.local import LocalArgBuilder +from smartsim.settings.builders.launch.local import ( + LocalArgBuilder, + _format_local_command, +) from smartsim.settings.launchCommand import LauncherType pytestmark = pytest.mark.group_a @@ -115,5 +118,5 @@ def test_format_env_vars(): def test_formatting_returns_original_exe(echo_executable_like): - cmd = LocalArgBuilder({}).finalize(echo_executable_like, {}) + cmd = _format_local_command(LocalArgBuilder({}), echo_executable_like, {}) assert tuple(cmd) == ("echo", "hello", "world") diff --git a/tests/temp_tests/test_settings/test_lsfLauncher.py b/tests/temp_tests/test_settings/test_lsfLauncher.py index 592c80ce7..91f65821b 100644 --- a/tests/temp_tests/test_settings/test_lsfLauncher.py +++ b/tests/temp_tests/test_settings/test_lsfLauncher.py @@ -1,7 +1,7 @@ import pytest from smartsim.settings import LaunchSettings -from smartsim.settings.builders.launch.lsf import JsrunArgBuilder +from smartsim.settings.builders.launch.lsf import JsrunArgBuilder, _format_jsrun_command from smartsim.settings.launchCommand import LauncherType pytestmark = pytest.mark.group_a @@ -92,5 +92,5 @@ def test_launch_args(): ), ) def test_formatting_launch_args(echo_executable_like, args, expected): - cmd = JsrunArgBuilder(args).finalize(echo_executable_like, {}) + cmd = _format_jsrun_command(JsrunArgBuilder(args), echo_executable_like, {}) assert tuple(cmd) == expected diff --git a/tests/temp_tests/test_settings/test_mpiLauncher.py b/tests/temp_tests/test_settings/test_mpiLauncher.py index 9b651c220..54fed657e 100644 --- a/tests/temp_tests/test_settings/test_mpiLauncher.py +++ b/tests/temp_tests/test_settings/test_mpiLauncher.py @@ -7,6 +7,9 @@ MpiArgBuilder, MpiexecArgBuilder, OrteArgBuilder, + _format_mpiexec_command, + _format_mpirun_command, + _format_orterun_command, ) from smartsim.settings.launchCommand import LauncherType @@ -210,11 +213,15 @@ def test_invalid_hostlist_format(launcher): @pytest.mark.parametrize( - "cls, cmd", + "cls, fmt, cmd", ( - pytest.param(MpiArgBuilder, "mpirun", id="w/ mpirun"), - pytest.param(MpiexecArgBuilder, "mpiexec", id="w/ mpiexec"), - pytest.param(OrteArgBuilder, "orterun", id="w/ orterun"), + pytest.param(MpiArgBuilder, _format_mpirun_command, "mpirun", id="w/ mpirun"), + pytest.param( + MpiexecArgBuilder, _format_mpiexec_command, "mpiexec", id="w/ mpiexec" + ), + pytest.param( + OrteArgBuilder, _format_orterun_command, "orterun", id="w/ orterun" + ), ), ) @pytest.mark.parametrize( @@ -248,6 +255,6 @@ def test_invalid_hostlist_format(launcher): ), ), ) -def test_formatting_launch_args(echo_executable_like, cls, cmd, args, expected): - fmt = cls(args).finalize(echo_executable_like, {}) - assert tuple(fmt) == (cmd,) + expected +def test_formatting_launch_args(echo_executable_like, cls, fmt, cmd, args, expected): + fmt_cmd = fmt(cls(args), echo_executable_like, {}) + assert tuple(fmt_cmd) == (cmd,) + expected diff --git a/tests/temp_tests/test_settings/test_palsLauncher.py b/tests/temp_tests/test_settings/test_palsLauncher.py index a0bc7821c..5b74e2d0c 100644 --- a/tests/temp_tests/test_settings/test_palsLauncher.py +++ b/tests/temp_tests/test_settings/test_palsLauncher.py @@ -1,7 +1,10 @@ import pytest from smartsim.settings import LaunchSettings -from smartsim.settings.builders.launch.pals import PalsMpiexecArgBuilder +from smartsim.settings.builders.launch.pals import ( + PalsMpiexecArgBuilder, + _format_mpiexec_command, +) from smartsim.settings.launchCommand import LauncherType pytestmark = pytest.mark.group_a @@ -103,5 +106,5 @@ def test_invalid_hostlist_format(): ), ) def test_formatting_launch_args(echo_executable_like, args, expected): - cmd = PalsMpiexecArgBuilder(args).finalize(echo_executable_like, {}) + cmd = _format_mpiexec_command(PalsMpiexecArgBuilder(args), echo_executable_like, {}) assert tuple(cmd) == expected diff --git a/tests/temp_tests/test_settings/test_slurmLauncher.py b/tests/temp_tests/test_settings/test_slurmLauncher.py index bfa7dd9e1..2a84c831e 100644 --- a/tests/temp_tests/test_settings/test_slurmLauncher.py +++ b/tests/temp_tests/test_settings/test_slurmLauncher.py @@ -1,7 +1,10 @@ import pytest from smartsim.settings import LaunchSettings -from smartsim.settings.builders.launch.slurm import SlurmArgBuilder +from smartsim.settings.builders.launch.slurm import ( + SlurmArgBuilder, + _format_srun_command, +) from smartsim.settings.launchCommand import LauncherType pytestmark = pytest.mark.group_a @@ -289,5 +292,5 @@ def test_set_het_groups(monkeypatch): ), ) def test_formatting_launch_args(echo_executable_like, args, expected): - cmd = SlurmArgBuilder(args).finalize(echo_executable_like, {}) + cmd = _format_srun_command(SlurmArgBuilder(args), echo_executable_like, {}) assert tuple(cmd) == expected diff --git a/tests/temp_tests/test_settings/test_slurmScheduler.py b/tests/temp_tests/test_settings/test_slurmScheduler.py index 0a34b6473..5c65d367a 100644 --- a/tests/temp_tests/test_settings/test_slurmScheduler.py +++ b/tests/temp_tests/test_settings/test_slurmScheduler.py @@ -105,6 +105,5 @@ def test_sbatch_manual(): slurmScheduler.scheduler_args.set_account("A3531") slurmScheduler.scheduler_args.set_walltime("10:00:00") formatted = slurmScheduler.format_batch_args() - print(f"here: {formatted}") result = ["--nodes=5", "--account=A3531", "--time=10:00:00"] assert formatted == result