diff --git a/conftest.py b/conftest.py index 4ddd99ce5..b0457522c 100644 --- a/conftest.py +++ b/conftest.py @@ -27,12 +27,15 @@ from __future__ import annotations import asyncio +from collections import defaultdict +from dataclasses import dataclass import json import os import pathlib import shutil import subprocess import signal +import socket import sys import tempfile import time @@ -40,6 +43,7 @@ import uuid import warnings from subprocess import run +import time import psutil import pytest @@ -53,7 +57,7 @@ from smartsim._core.utils.telemetry.telemetry import JobEntity from smartsim.database import Orchestrator from smartsim.entity import Model -from smartsim.error import SSConfigError +from smartsim.error import SSConfigError, SSInternalError from smartsim.log import get_logger from smartsim.settings import ( AprunSettings, @@ -78,7 +82,7 @@ test_num_gpus = CONFIG.test_num_gpus test_nic = CONFIG.test_interface test_alloc_specs_path = os.getenv("SMARTSIM_TEST_ALLOC_SPEC_SHEET_PATH", None) -test_port = CONFIG.test_port +test_ports = CONFIG.test_ports test_account = CONFIG.test_account or "" test_batch_resources: t.Dict[t.Any, t.Any] = CONFIG.test_batch_resources test_output_dirs = 0 @@ -89,7 +93,6 @@ test_hostlist = None has_aprun = shutil.which("aprun") is not None - def get_account() -> str: return test_account @@ -109,9 +112,7 @@ def print_test_configuration() -> None: print("TEST_ALLOC_SPEC_SHEET_PATH:", test_alloc_specs_path) print("TEST_DIR:", test_output_root) print("Test output will be located in TEST_DIR if there is a failure") - print( - "TEST_PORTS:", ", ".join(str(port) for port in range(test_port, test_port + 3)) - ) + print("TEST_PORTS:", ", ".join(str(port) for port in test_ports)) if test_batch_resources: print("TEST_BATCH_RESOURCES: ") print(json.dumps(test_batch_resources, indent=2)) @@ -297,7 +298,23 @@ def _reset(): ) -@pytest.fixture +def _find_free_port(ports: t.Collection[int]) -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + for port in ports: + try: + sock.bind(("127.0.0.1", port)) + except socket.error: + continue + else: + _, port_ = sock.getsockname() + return int(port_) + raise SSInternalError( + "Could not find a free port out of the provided options: " + f"{', '.join(str(port) for port in sorted(ports))}" + ) + + +@pytest.fixture(scope="session") def wlmutils() -> t.Type[WLMUtils]: return WLMUtils @@ -314,7 +331,9 @@ def get_test_launcher() -> str: @staticmethod def get_test_port() -> int: - return test_port + # TODO: Ideally this should find a free port on the correct host(s), + # but this is good enough for now + return _find_free_port(test_ports) @staticmethod def get_test_account() -> str: @@ -420,61 +439,6 @@ def get_run_settings( return RunSettings(exe, args) - @staticmethod - def get_orchestrator(nodes: int = 1, batch: bool = False) -> Orchestrator: - if test_launcher == "pbs": - if not shutil.which("aprun"): - hostlist = get_hostlist() - else: - hostlist = None - return Orchestrator( - db_nodes=nodes, - port=test_port, - batch=batch, - interface=test_nic, - launcher=test_launcher, - hosts=hostlist, - ) - if test_launcher == "pals": - hostlist = get_hostlist() - return Orchestrator( - db_nodes=nodes, - port=test_port, - batch=batch, - interface=test_nic, - launcher=test_launcher, - hosts=hostlist, - ) - if test_launcher == "slurm": - return Orchestrator( - db_nodes=nodes, - port=test_port, - batch=batch, - interface=test_nic, - launcher=test_launcher, - ) - if test_launcher == "dragon": - return Orchestrator( - db_nodes=nodes, -
port=test_port, - batch=batch, - interface=test_nic, - launcher=test_launcher, - ) - if test_launcher == "lsf": - return Orchestrator( - db_nodes=nodes, - port=test_port, - batch=batch, - cpus_per_shard=4, - gpus_per_shard=2 if test_device == "GPU" else 0, - project=get_account(), - interface=test_nic, - launcher=test_launcher, - ) - - return Orchestrator(port=test_port, interface="lo") - @staticmethod def choose_host(rs: RunSettings) -> t.Optional[str]: if isinstance(rs, (MpirunSettings, MpiexecSettings)): @@ -485,65 +449,6 @@ def choose_host(rs: RunSettings) -> t.Optional[str]: return None -@pytest.fixture -def local_db( - request: t.Any, wlmutils: t.Type[WLMUtils], test_dir: str -) -> t.Generator[Orchestrator, None, None]: - """Yield fixture for startup and teardown of an local orchestrator""" - - exp_name = request.function.__name__ - exp = Experiment(exp_name, launcher="local", exp_path=test_dir) - db = Orchestrator(port=wlmutils.get_test_port(), interface="lo") - db.set_path(test_dir) - exp.start(db) - - yield db - # pass or fail, the teardown code below is ran after the - # completion of a test case that uses this fixture - exp.stop(db) - - -@pytest.fixture -def db( - request: t.Any, wlmutils: t.Type[WLMUtils], test_dir: str -) -> t.Generator[Orchestrator, None, None]: - """Yield fixture for startup and teardown of an orchestrator""" - launcher = wlmutils.get_test_launcher() - - exp_name = request.function.__name__ - exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) - db = wlmutils.get_orchestrator() - db.set_path(test_dir) - exp.start(db) - - yield db - # pass or fail, the teardown code below is ran after the - # completion of a test case that uses this fixture - exp.stop(db) - - -@pytest.fixture -def db_cluster( - test_dir: str, wlmutils: t.Type[WLMUtils], request: t.Any -) -> t.Generator[Orchestrator, None, None]: - """ - Yield fixture for startup and teardown of a clustered orchestrator. - This should only be used in on_wlm and full_wlm tests. 
- """ - launcher = wlmutils.get_test_launcher() - - exp_name = request.function.__name__ - exp = Experiment(exp_name, launcher=launcher, exp_path=test_dir) - db = wlmutils.get_orchestrator(nodes=3) - db.set_path(test_dir) - exp.start(db) - - yield db - # pass or fail, the teardown code below is ran after the - # completion of a test case that uses this fixture - exp.stop(db) - - @pytest.fixture(scope="function", autouse=True) def environment_cleanup(monkeypatch: pytest.MonkeyPatch) -> None: for key in os.environ.keys(): @@ -750,7 +655,7 @@ def setup_test_colo( db_args: t.Dict[str, t.Any], colo_settings: t.Optional[RunSettings] = None, colo_model_name: str = "colocated_model", - port: int = test_port, + port: t.Optional[int] = None, on_wlm: bool = False, ) -> Model: """Setup database needed for the colo pinning tests""" @@ -766,10 +671,11 @@ def setup_test_colo( if on_wlm: colo_settings.set_tasks(1) colo_settings.set_nodes(1) + colo_model = exp.create_model(colo_model_name, colo_settings) if db_type in ["tcp", "deprecated"]: - db_args["port"] = port + db_args["port"] = port if port is not None else _find_free_port(test_ports) db_args["ifname"] = "lo" if db_type == "uds" and colo_model_name is not None: tmp_dir = tempfile.gettempdir() @@ -968,3 +874,151 @@ def num_calls(self) -> int: @property def details(self) -> t.List[t.Tuple[t.Tuple[t.Any, ...], t.Dict[str, t.Any]]]: return self._details + +## Reuse database across tests + +database_registry: t.DefaultDict[str, t.Optional[Orchestrator]] = defaultdict(lambda: None) + +@pytest.fixture(scope="function") +def local_experiment(test_dir: str) -> smartsim.Experiment: + """Create a default experiment that uses the requested launcher""" + name = pathlib.Path(test_dir).stem + return smartsim.Experiment(name, exp_path=test_dir, launcher="local") + +@pytest.fixture(scope="function") +def wlm_experiment(test_dir: str, wlmutils: WLMUtils) -> smartsim.Experiment: + """Create a default experiment that uses the requested launcher""" + name = pathlib.Path(test_dir).stem + return smartsim.Experiment( + name, + exp_path=test_dir, + launcher=wlmutils.get_test_launcher() + ) + +def _cleanup_db(name: str) -> None: + global database_registry + db = database_registry[name] + if db and db.is_active(): + exp = Experiment("cleanup") + try: + db = exp.reconnect_orchestrator(db.checkpoint_file) + exp.stop(db) + except: + pass + +@dataclass +class DBConfiguration: + name: str + launcher: str + num_nodes: int + interface: t.Union[str,t.List[str]] + hostlist: t.Optional[t.List[str]] + port: int + +@dataclass +class PrepareDatabaseOutput: + orchestrator: t.Optional[Orchestrator] # The actual orchestrator object + new_db: bool # True if a new database was created when calling prepare_db + +# Reuse databases +@pytest.fixture(scope="session") +def local_db() -> t.Generator[DBConfiguration, None, None]: + name = "local_db_fixture" + config = DBConfiguration( + name, + "local", + 1, + "lo", + None, + _find_free_port(tuple(reversed(test_ports))), + ) + yield config + _cleanup_db(name) + +@pytest.fixture(scope="session") +def single_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]: + hostlist = wlmutils.get_test_hostlist() + hostlist = hostlist[-1:] if hostlist is not None else None + name = "single_db_fixture" + config = DBConfiguration( + name, + wlmutils.get_test_launcher(), + 1, + wlmutils.get_test_interface(), + hostlist, + _find_free_port(tuple(reversed(test_ports))) + ) + yield config + _cleanup_db(name) + + +@pytest.fixture(scope="session") +def 
clustered_db(wlmutils: WLMUtils) -> t.Generator[DBConfiguration, None, None]: + hostlist = wlmutils.get_test_hostlist() + hostlist = hostlist[-4:-1] if hostlist is not None else None + name = "clustered_db_fixture" + config = DBConfiguration( + name, + wlmutils.get_test_launcher(), + 3, + wlmutils.get_test_interface(), + hostlist, + _find_free_port(tuple(reversed(test_ports))), + ) + yield config + _cleanup_db(name) + + +@pytest.fixture +def register_new_db() -> t.Callable[[DBConfiguration], Orchestrator]: + def _register_new_db( + config: DBConfiguration + ) -> Orchestrator: + exp_path = pathlib.Path(test_output_root, config.name) + exp_path.mkdir(exist_ok=True) + exp = Experiment( + config.name, + exp_path=str(exp_path), + launcher=config.launcher, + ) + orc = exp.create_database( + port=config.port, + batch=False, + interface=config.interface, + hosts=config.hostlist, + db_nodes=config.num_nodes + ) + exp.generate(orc, overwrite=True) + exp.start(orc) + global database_registry + database_registry[config.name] = orc + return orc + return _register_new_db + + +@pytest.fixture(scope="function") +def prepare_db( + register_new_db: t.Callable[ + [DBConfiguration], + Orchestrator + ] +) -> t.Callable[ + [DBConfiguration], + PrepareDatabaseOutput +]: + def _prepare_db(db_config: DBConfiguration) -> PrepareDatabaseOutput: + global database_registry + db = database_registry[db_config.name] + + new_db = False + db_up = False + + if db: + db_up = db.is_active() + + if not db_up or db is None: + db = register_new_db(db_config) + new_db = True + + return PrepareDatabaseOutput(db, new_db) + return _prepare_db diff --git a/doc/api/smartsim_api.rst b/doc/api/smartsim_api.rst index 88c173783..d9615e04c 100644 --- a/doc/api/smartsim_api.rst +++ b/doc/api/smartsim_api.rst @@ -433,6 +433,8 @@ Orchestrator Orchestrator.set_max_message_size Orchestrator.set_db_conf Orchestrator.telemetry + Orchestrator.checkpoint_file + Orchestrator.batch Orchestrator ------------ diff --git a/doc/changelog.md b/doc/changelog.md index afb5b8e90..be971bce7 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -17,6 +17,7 @@ Description - Add dragon runtime installer - Add launcher based on Dragon +- Reuse Orchestrators within the testing suite to improve performance. - Fix building of documentation - Preview entities on experiment before start - Update authentication in release workflow @@ -70,6 +71,10 @@ Detailed Notes or by using ``DragonRunSettings`` to launch a job. The Dragon launcher is at an early stage of development: early adopters are referred to the dedicated documentation section to learn more about it. ([SmartSim-PR580](https://github.com/CrayLabs/SmartSim/pull/580)) +- Tests may now request a given configuration and will reconnect to + the existing orchestrator instead of building up and tearing down + a new one each test. + ([SmartSim-PR567](https://github.com/CrayLabs/SmartSim/pull/567)) - Manually ensure that typing_extensions==4.6.1 in Dockerfile used to build docs. This fixes the deploy_dev_docs Github action ([SmartSim-PR564](https://github.com/CrayLabs/SmartSim/pull/564)) - Added preview functionality to Experiment, including preview of all entities, active infrastructure and @@ -118,7 +123,7 @@ Detailed Notes Torch will unconditionally try to link in this library, however fails because the linking flags are incorrect. 
([SmartSim-PR538](https://github.com/CrayLabs/SmartSim/pull/538)) -- Change type_extension and pydantic versions in readthedocs +- Change typing\_extensions and pydantic versions in readthedocs environment to enable docs build. ([SmartSim-PR537](https://github.com/CrayLabs/SmartSim/pull/537)) - Promote devices to a dedicated Enum type throughout the SmartSim diff --git a/doc/testing.rst b/doc/testing.rst index ccb2db3c2..08cce5d36 100644 --- a/doc/testing.rst +++ b/doc/testing.rst @@ -66,20 +66,20 @@ of the tests located within the ``on_wlm`` directory. To run the ``on_wlm`` test suite, users will have to be on a system with one of the supported workload managers. Additionally, users will -need to obtain an allocation of **at least 4 nodes**. +need to obtain an allocation of **at least 8 nodes**. Examples of how to obtain allocations on systems with the launchers: .. code:: bash # for slurm (with srun) - salloc -N 4 -A account --exclusive -t 00:10:00 + salloc -N 8 -A account --exclusive -t 00:10:00 # for PBSPro (with aprun) - qsub -l select=4 -l place=scatter -l walltime=00:10:00 -q queue + qsub -l select=8 -l place=scatter -l walltime=00:10:00 -q queue # for LSF (with jsrun) - bsub -Is -W 00:30 -nnodes 4 -P project $SHELL + bsub -Is -W 00:30 -nnodes 8 -P project $SHELL Values for queue, account, or project should be substituted appropriately. @@ -119,7 +119,7 @@ A full example on an internal SLURM system .. code:: bash - salloc -N 4 -A account --exclusive -t 03:00:00 + salloc -N 8 -A account --exclusive -t 03:00:00 export SMARTSIM_TEST_LAUNCHER=slurm export SMARTSIM_TEST_INTERFACE=ipogif0 export SMARTSIM_TEST_DEVICE=gpu diff --git a/smartsim/_core/config/config.py b/smartsim/_core/config/config.py index 3e43d7de4..9cf950b21 100644 --- a/smartsim/_core/config/config.py +++ b/smartsim/_core/config/config.py @@ -203,8 +203,14 @@ def test_num_gpus(self) -> int: # pragma: no cover return int(os.environ.get("SMARTSIM_TEST_NUM_GPUS") or 1) @property - def test_port(self) -> int: # pragma: no cover - return int(os.environ.get("SMARTSIM_TEST_PORT", 6780)) + def test_ports(self) -> t.Sequence[int]: # pragma: no cover + min_required_ports = 25 + first_port = int(os.environ.get("SMARTSIM_TEST_PORT", 6780)) + num_ports = max( + int(os.environ.get("SMARTSIM_TEST_NUM_PORTS", min_required_ports)), + min_required_ports, + ) + return range(first_port, first_port + num_ports) @property def test_batch_resources(self) -> t.Dict[t.Any, t.Any]: # pragma: no cover diff --git a/smartsim/_core/control/controller.py b/smartsim/_core/control/controller.py index 2aaf8dad7..43a218545 100644 --- a/smartsim/_core/control/controller.py +++ b/smartsim/_core/control/controller.py @@ -27,6 +27,7 @@ from __future__ import annotations import itertools +import os import os.path as osp import pathlib import pickle @@ -36,7 +37,6 @@ import threading import time import typing as t -from os import environ from smartredis import Client, ConfigOptions @@ -749,14 +749,26 @@ def _save_orchestrator(self, orchestrator: Orchestrator) -> None: :param orchestrator: Orchestrator configuration to be saved """ - dat_file = "/".join((orchestrator.path, "smartsim_db.dat")) - db_jobs = self._jobs.db_jobs - orc_data = {"db": orchestrator, "db_jobs": db_jobs} - steps = [] - for db_job in db_jobs.values(): - steps.append(self._launcher.step_mapping[db_job.name]) - orc_data["steps"] = steps - with open(dat_file, "wb") as pickle_file: + if not orchestrator.is_active(): + raise Exception("Orchestrator is not running") + + # Extract only the db_jobs 
associated with this particular orchestrator + if orchestrator.batch: + job_names = [orchestrator.name] + else: + job_names = [dbnode.name for dbnode in orchestrator.entities] + db_jobs = { + name: job for name, job in self._jobs.db_jobs.items() if name in job_names + } + + # Extract the associated steps + steps = [ + self._launcher.step_mapping[db_job.name] for db_job in db_jobs.values() + ] + + orc_data = {"db": orchestrator, "db_jobs": db_jobs, "steps": steps} + + with open(orchestrator.checkpoint_file, "wb") as pickle_file: pickle.dump(orc_data, pickle_file) def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None: @@ -787,8 +799,7 @@ def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None: statuses = self.get_entity_list_status(orchestrator) if all(stat == SmartSimStatus.STATUS_RUNNING for stat in statuses): ready = True - # TODO remove in favor of by node status check - time.sleep(CONFIG.jm_interval) + # TODO: Add a node status check elif any(stat in TERMINAL_STATUSES for stat in statuses): self.stop_db(orchestrator) msg = "Orchestrator failed during startup" @@ -806,14 +817,14 @@ def _orchestrator_launch_wait(self, orchestrator: Orchestrator) -> None: # launch explicitly raise - def reload_saved_db(self, checkpoint_file: str) -> Orchestrator: + def reload_saved_db( + self, checkpoint_file: t.Union[str, os.PathLike[str]] + ) -> Orchestrator: with JM_LOCK: - if self.orchestrator_active: - raise SmartSimError("Orchestrator exists and is active") if not osp.exists(checkpoint_file): raise FileNotFoundError( - f"The SmartSim database config file {checkpoint_file} " + f"The SmartSim database config file {os.fspath(checkpoint_file)} " "cannot be found." ) @@ -878,9 +889,9 @@ def _set_dbobjects(self, manifest: Manifest) -> None: if not db_is_active(hosts=hosts, ports=ports, num_shards=len(db_addresses)): raise SSInternalError("Cannot set DB Objects, DB is not running") - environ[f"SSDB{db_name}"] = db_addresses[0] + os.environ[f"SSDB{db_name}"] = db_addresses[0] - environ[f"SR_DB_TYPE{db_name}"] = ( + os.environ[f"SR_DB_TYPE{db_name}"] = ( CLUSTERED if len(db_addresses) > 1 else STANDALONE ) diff --git a/smartsim/_core/launcher/dragon/dragonLauncher.py b/smartsim/_core/launcher/dragon/dragonLauncher.py index 041257366..c13eefedd 100644 --- a/smartsim/_core/launcher/dragon/dragonLauncher.py +++ b/smartsim/_core/launcher/dragon/dragonLauncher.py @@ -185,7 +185,7 @@ def run(self, step: Step) -> t.Optional[str]: ), DragonRunResponse, ) - step_id = task_id = str(response.step_id) + step_id = str(response.step_id) else: # pylint: disable-next=consider-using-with out_strm = open(out, "w+", encoding="utf-8") diff --git a/smartsim/database/orchestrator.py b/smartsim/database/orchestrator.py index f48ef2857..f6ce0310f 100644 --- a/smartsim/database/orchestrator.py +++ b/smartsim/database/orchestrator.py @@ -27,6 +27,7 @@ # pylint: disable=too-many-lines import itertools +import os.path as osp import sys import typing as t from os import environ, getcwd, getenv @@ -41,7 +42,12 @@ from .._core.utils.helpers import is_valid_cmd, unpack_db_identifier from .._core.utils.network import get_ip_from_host from ..entity import DBNode, EntityList, TelemetryConfiguration -from ..error import SmartSimError, SSConfigError, SSUnsupportedError +from ..error import ( + SmartSimError, + SSConfigError, + SSDBFilesNotParseable, + SSUnsupportedError, +) from ..log import get_logger from ..servertype import CLUSTERED, STANDALONE from ..settings import ( @@ -147,6 +153,7 @@ def 
_check_local_constraints(launcher: str, batch: bool) -> None: raise SmartSimError(msg) +# pylint: disable-next=too-many-public-methods class Orchestrator(EntityList[DBNode]): """The Orchestrator is an in-memory database that can be launched alongside entities in SmartSim. Data can be transferred between @@ -370,10 +377,11 @@ def is_active(self) -> bool: :return: True if database is active, False otherwise """ - if not self._hosts: + try: + hosts = self.hosts + except SSDBFilesNotParseable: return False - - return db_is_active(self._hosts, self.ports, self.num_shards) + return db_is_active(hosts, self.ports, self.num_shards) @property def _rai_module(self) -> t.Tuple[str, ...]: @@ -399,6 +407,14 @@ def _redis_exe(self) -> str: def _redis_conf(self) -> str: return CONFIG.database_conf + @property + def checkpoint_file(self) -> str: + """Get the path to the checkpoint file for this Orchestrator + + :return: Path to the checkpoint file + """ + return osp.join(self.path, "smartsim_db.dat") + def set_cpus(self, num_cpus: int) -> None: """Set the number of CPUs available to each database shard @@ -451,9 +467,8 @@ def set_hosts(self, host_list: t.Union[t.List[str], str]) -> None: raise TypeError("host_list argument must be list of strings") self._user_hostlist = host_list.copy() # TODO check length - if self.batch: - if hasattr(self, "batch_settings") and self.batch_settings: - self.batch_settings.set_hostlist(host_list) + if self.batch and hasattr(self, "batch_settings") and self.batch_settings: + self.batch_settings.set_hostlist(host_list) if self.launcher == "lsf": for db in self.entities: diff --git a/smartsim/entity/dbnode.py b/smartsim/entity/dbnode.py index 485bbcd88..d371357f8 100644 --- a/smartsim/entity/dbnode.py +++ b/smartsim/entity/dbnode.py @@ -34,7 +34,7 @@ from dataclasses import dataclass from .._core.config import CONFIG -from ..error import SmartSimError +from ..error import SSDBFilesNotParseable from ..log import get_logger from ..settings.base import RunSettings from .entity import SmartSimEntity @@ -184,7 +184,7 @@ def _parse_launched_shard_info_from_files( def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": """Parse the launched database shard info from the output files - :raises SmartSimError: if all shard info could not be found + :raises SSDBFilesNotParseable: if all shard info could not be found :return: The found launched shard info """ ips: "t.List[LaunchedShardData]" = [] @@ -211,7 +211,7 @@ def get_launched_shard_info(self) -> "t.List[LaunchedShardData]": f"{len(ips)} out of {self.num_shards} DB shards." ) logger.error(msg) - raise SmartSimError(msg) + raise SSDBFilesNotParseable(msg) return ips def _parse_db_hosts(self) -> t.List[str]: @@ -220,7 +220,7 @@ The IP address is preferred, but if hostname is only present then a lookup to /etc/hosts is done through the socket library.
- :raises SmartSimError: if host/ip could not be found + :raises SSDBFilesNotParseable: if host/ip could not be found :return: ip addresses | hostnames """ return list({shard.hostname for shard in self.get_launched_shard_info()}) diff --git a/smartsim/entity/entityList.py b/smartsim/entity/entityList.py index 6d958bda6..edaa88668 100644 --- a/smartsim/entity/entityList.py +++ b/smartsim/entity/entityList.py @@ -91,16 +91,14 @@ def db_scripts(self) -> t.Iterable["smartsim.entity.DBScript"]: @property def batch(self) -> bool: - try: - if not hasattr(self, "batch_settings"): - return False - - if self.batch_settings: - return True - return False - # local orchestrator cannot launch with batches - except AttributeError: - return False + """Property indicating whether or not the entity sequence should be + launched as a batch job + + :return: ``True`` if entity sequence should be launched as a batch job, + ``False`` if the members will be launched individually. + """ + # pylint: disable-next=no-member + return hasattr(self, "batch_settings") and self.batch_settings @property def type(self) -> str: diff --git a/smartsim/error/__init__.py b/smartsim/error/__init__.py index 4268905e6..3a40548e7 100644 --- a/smartsim/error/__init__.py +++ b/smartsim/error/__init__.py @@ -32,6 +32,7 @@ ShellError, SmartSimError, SSConfigError, + SSDBFilesNotParseable, SSDBIDConflictError, SSInternalError, SSReservedKeywordError, diff --git a/smartsim/error/errors.py b/smartsim/error/errors.py index 65bc10857..333258a34 100644 --- a/smartsim/error/errors.py +++ b/smartsim/error/errors.py @@ -87,6 +87,12 @@ class SSDBIDConflictError(SmartSimError): """ +class SSDBFilesNotParseable(SmartSimError): + """Raised when the files related to the database cannot be parsed. + Includes the case when the files do not exist. + """ + + # Internal Exceptions diff --git a/smartsim/experiment.py b/smartsim/experiment.py index cee3872bb..6b9d6a1fb 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -733,7 +733,7 @@ def create_database( batch: bool = False, hosts: t.Optional[t.Union[t.List[str], str]] = None, run_command: str = "auto", - interface: str = "ipogif0", + interface: t.Union[str, t.List[str]] = "ipogif0", account: t.Optional[str] = None, time: t.Optional[str] = None, queue: t.Optional[str] = None, diff --git a/smartsim/ml/data.py b/smartsim/ml/data.py index a24dabc9f..6175259b2 100644 --- a/smartsim/ml/data.py +++ b/smartsim/ml/data.py @@ -285,6 +285,7 @@ def __init__( verbose: bool = False, init_samples: bool = True, max_fetch_trials: int = -1, + wait_interval: float = 10.0, ) -> None: self.address = address self.cluster = cluster @@ -311,7 +312,7 @@ def __init__( self.set_replica_parameters(replica_rank, num_replicas) if init_samples: - self.init_samples(max_fetch_trials) + self.init_samples(max_fetch_trials, wait_interval) @property def client(self) -> Client: @@ -378,7 +379,7 @@ def __iter__( self._data_generation(self._calc_indices(idx)) for idx in range(len(self)) ) - def init_samples(self, init_trials: int = -1) -> None: + def init_samples(self, init_trials: int = -1, wait_interval: float = 10.0) -> None: """Initialize samples (and targets, if needed). 
A new attempt to download samples will be made every ten seconds, @@ -392,10 +393,10 @@ def init_samples(self, init_trials: int = -1) -> None: max_trials = init_trials or -1 while not self and num_trials != max_trials: self._update_samples_and_targets() - self.log( - "DataLoader could not download samples, will try again in 10 seconds" - ) - time.sleep(10) + msg = "DataLoader could not download samples, will try again in " + msg += f"{wait_interval} seconds" + self.log(msg) + time.sleep(wait_interval) num_trials += 1 if not self: diff --git a/tests/backends/test_cli_mini_exp.py b/tests/backends/test_cli_mini_exp.py index f7563fc96..2fde2ff5f 100644 --- a/tests/backends/test_cli_mini_exp.py +++ b/tests/backends/test_cli_mini_exp.py @@ -48,6 +48,7 @@ def test_cli_mini_exp_doesnt_error_out_with_dev_build( + prepare_db, local_db, test_dir, monkeypatch, @@ -57,9 +58,11 @@ def test_cli_mini_exp_doesnt_error_out_with_dev_build( to ensure that it does not accidentally report false positive/negatives """ + db = prepare_db(local_db).orchestrator + @contextmanager def _mock_make_managed_local_orc(*a, **kw): - (client_addr,) = local_db.get_address() + (client_addr,) = db.get_address() yield smartredis.Client(False, address=client_addr) monkeypatch.setattr( @@ -68,7 +71,7 @@ def _mock_make_managed_local_orc(*a, **kw): _mock_make_managed_local_orc, ) backends = installed_redisai_backends() - (db_port,) = local_db.ports + (db_port,) = db.ports smartsim._core._cli.validate.test_install( # Shouldn't matter bc we are stubbing creation of orc diff --git a/tests/backends/test_dataloader.py b/tests/backends/test_dataloader.py index e377f5631..de4bf6d8e 100644 --- a/tests/backends/test_dataloader.py +++ b/tests/backends/test_dataloader.py @@ -167,19 +167,16 @@ def train_tf(generator): @pytest.mark.skipif(not shouldrun_tf, reason="Test needs TensorFlow to run") -def test_tf_dataloaders(test_dir, wlmutils): - exp = Experiment( - "test_tf_dataloaders", test_dir, launcher=wlmutils.get_test_launcher() - ) - orc: Orchestrator = wlmutils.get_orchestrator() - exp.generate(orc) - exp.start(orc) +def test_tf_dataloaders(wlm_experiment, prepare_db, single_db, monkeypatch): + + db = prepare_db(single_db).orchestrator + orc = wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + monkeypatch.setenv("SSDB", orc.get_address()[0]) + monkeypatch.setenv("SSKEYIN", "test_uploader_0,test_uploader_1") try: - os.environ["SSDB"] = orc.get_address()[0] data_info = run_local_uploaders(mpi_size=2, format="tf") - os.environ["SSKEYIN"] = "test_uploader_0,test_uploader_1" for rank in range(2): tf_dynamic = TFDataGenerator( data_info_or_list_name="test_data_list", @@ -190,6 +187,7 @@ def test_tf_dataloaders(test_dir, wlmutils): batch_size=4, max_fetch_trials=5, dynamic=False, # catch wrong arg + wait_interval=0.1, ) train_tf(tf_dynamic) assert len(tf_dynamic) == 4 @@ -204,6 +202,7 @@ def test_tf_dataloaders(test_dir, wlmutils): batch_size=4, max_fetch_trials=5, dynamic=True, # catch wrong arg + wait_interval=0.1, ) train_tf(tf_static) assert len(tf_static) == 4 @@ -211,11 +210,6 @@ def test_tf_dataloaders(test_dir, wlmutils): except Exception as e: raise e - finally: - exp.stop(orc) - os.environ.pop("SSDB", "") - os.environ.pop("SSKEYIN", "") - os.environ.pop("SSKEYOUT", "") def create_trainer_torch(experiment: Experiment, filedir, wlmutils): @@ -234,20 +228,18 @@ def create_trainer_torch(experiment: Experiment, filedir, wlmutils): @pytest.mark.skipif(not shouldrun_torch, reason="Test needs Torch to run") -def 
test_torch_dataloaders(fileutils, test_dir, wlmutils): - exp = Experiment( - "test_tf_dataloaders", test_dir, launcher=wlmutils.get_test_launcher() - ) - orc: Orchestrator = wlmutils.get_orchestrator() +def test_torch_dataloaders( + wlm_experiment, prepare_db, single_db, fileutils, test_dir, wlmutils, monkeypatch +): config_dir = fileutils.get_test_dir_path("ml") - exp.generate(orc) - exp.start(orc) + db = prepare_db(single_db).orchestrator + orc = wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + monkeypatch.setenv("SSDB", orc.get_address()[0]) + monkeypatch.setenv("SSKEYIN", "test_uploader_0,test_uploader_1") try: - os.environ["SSDB"] = orc.get_address()[0] data_info = run_local_uploaders(mpi_size=2) - os.environ["SSKEYIN"] = "test_uploader_0,test_uploader_1" for rank in range(2): torch_dynamic = TorchDataGenerator( data_info_or_list_name="test_data_list", @@ -258,11 +250,12 @@ def test_torch_dataloaders(fileutils, test_dir, wlmutils): batch_size=4, max_fetch_trials=5, dynamic=False, # catch wrong arg - init_samples=True, # catch wrong arg + init_samples=True, + wait_interval=0.1, ) check_dataloader(torch_dynamic, rank, dynamic=True) - torch_dynamic.init_samples(5) + torch_dynamic.init_samples(5, 0.1) for _ in range(2): for _ in torch_dynamic: continue @@ -278,26 +271,22 @@ def test_torch_dataloaders(fileutils, test_dir, wlmutils): max_fetch_trials=5, dynamic=True, # catch wrong arg init_samples=True, # catch wrong arg + wait_interval=0.1, ) check_dataloader(torch_static, rank, dynamic=False) - torch_static.init_samples(5) + torch_static.init_samples(5, 0.1) for _ in range(2): for _ in torch_static: continue - trainer = create_trainer_torch(exp, config_dir, wlmutils) - exp.start(trainer, block=True) + trainer = create_trainer_torch(wlm_experiment, config_dir, wlmutils) + wlm_experiment.start(trainer, block=True) - assert exp.get_status(trainer)[0] == SmartSimStatus.STATUS_COMPLETED + assert wlm_experiment.get_status(trainer)[0] == SmartSimStatus.STATUS_COMPLETED except Exception as e: raise e - finally: - exp.stop(orc) - os.environ.pop("SSDB", "") - os.environ.pop("SSKEYIN", "") - os.environ.pop("SSKEYOUT", "") def test_data_info_repr(): @@ -331,15 +320,9 @@ def test_data_info_repr(): @pytest.mark.skipif( not (shouldrun_torch or shouldrun_tf), reason="Requires TF or PyTorch" ) -def test_wrong_dataloaders(test_dir, wlmutils): - exp = Experiment( - "test-wrong-dataloaders", - exp_path=test_dir, - launcher=wlmutils.get_test_launcher(), - ) - orc = wlmutils.get_orchestrator() - exp.generate(orc) - exp.start(orc) +def test_wrong_dataloaders(wlm_experiment, prepare_db, single_db): + db = prepare_db(single_db).orchestrator + orc = wlm_experiment.reconnect_orchestrator(db.checkpoint_file) if shouldrun_tf: with pytest.raises(SSInternalError): @@ -365,5 +348,3 @@ def test_wrong_dataloaders(test_dir, wlmutils): cluster=False, ) torch_data_gen.init_samples(init_trials=1) - - exp.stop(orc) diff --git a/tests/backends/test_dbmodel.py b/tests/backends/test_dbmodel.py index eb0198229..6155b6884 100644 --- a/tests/backends/test_dbmodel.py +++ b/tests/backends/test_dbmodel.py @@ -146,36 +146,30 @@ def save_torch_cnn(path, file_name): @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_tf_db_model(fileutils, test_dir, wlmutils, mlutils): +def test_tf_db_model( + wlm_experiment, prepare_db, single_db, fileutils, test_dir, mlutils +): """Test TensorFlow DB Models on remote DB""" - # Set experiment name - exp_name = "test-tf-db-model" - # Retrieve parameters from testing 
environment - test_launcher = wlmutils.get_test_launcher() - test_interface = wlmutils.get_test_interface() - test_port = wlmutils.get_test_port() test_device = mlutils.get_test_device() test_num_gpus = 1 # TF backend fails on multiple GPUs test_script = fileutils.get_test_conf_path("run_tf_dbmodel_smartredis.py") - # Create the SmartSim Experiment - exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # Create RunSettings - run_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) + run_settings = wlm_experiment.create_run_settings( + exe=sys.executable, exe_args=test_script + ) run_settings.set_nodes(1) run_settings.set_tasks(1) # Create Model - smartsim_model = exp.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) # Create database - host = wlmutils.choose_host(run_settings) - db = exp.create_database(port=test_port, interface=test_interface, hosts=host) - exp.generate(db) + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) # Create and save ML model to filesystem model, inputs, outputs = create_tf_cnn() @@ -212,50 +206,41 @@ def test_tf_db_model(fileutils, test_dir, wlmutils, mlutils): # Assert we have added both models assert len(smartsim_model._db_models) == 2 - exp.generate(smartsim_model) + wlm_experiment.generate(smartsim_model) # Launch and check successful completion - try: - exp.start(db, smartsim_model, block=True) - statuses = exp.get_status(smartsim_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - finally: - exp.stop(db) + wlm_experiment.start(smartsim_model, block=True) + statuses = wlm_experiment.get_status(smartsim_model) + assert all( + stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + ), f"Statuses: {statuses}" @pytest.mark.skipif(not should_run_pt, reason="Test needs PyTorch to run") -def test_pt_db_model(fileutils, test_dir, wlmutils, mlutils): +def test_pt_db_model( + wlm_experiment, prepare_db, single_db, fileutils, test_dir, mlutils +): """Test PyTorch DB Models on remote DB""" - # Set experiment name - exp_name = "test-pt-db-model" - # Retrieve parameters from testing environment - test_launcher = wlmutils.get_test_launcher() - test_interface = wlmutils.get_test_interface() - test_port = wlmutils.get_test_port() test_device = mlutils.get_test_device() test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1 test_script = fileutils.get_test_conf_path("run_pt_dbmodel_smartredis.py") - # Create the SmartSim Experiment - exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # Create RunSettings - run_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) + run_settings = wlm_experiment.create_run_settings( + exe=sys.executable, exe_args=test_script + ) run_settings.set_nodes(1) run_settings.set_tasks(1) # Create Model - smartsim_model = exp.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) # Create database - host = wlmutils.choose_host(run_settings) - db = exp.create_database(port=test_port, interface=test_interface, hosts=host) - exp.generate(db) + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) # Create and save ML model to filesystem save_torch_cnn(test_dir, "model1.pt") @@ -279,55 +264,46 @@ def test_pt_db_model(fileutils, test_dir, 
wlmutils, mlutils): # Assert we have added both models assert len(smartsim_model._db_models) == 1 - exp.generate(smartsim_model) + wlm_experiment.generate(smartsim_model) # Launch and check successful completion - try: - exp.start(db, smartsim_model, block=True) - statuses = exp.get_status(smartsim_model) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - finally: - exp.stop(db) + wlm_experiment.start(smartsim_model, block=True) + statuses = wlm_experiment.get_status(smartsim_model) + assert all( + stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + ), f"Statuses: {statuses}" @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") -def test_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): +def test_db_model_ensemble( + wlm_experiment, prepare_db, single_db, fileutils, test_dir, wlmutils, mlutils +): """Test DBModels on remote DB, with an ensemble""" - # Set experiment name - exp_name = "test-db-model-ensemble" - # Retrieve parameters from testing environment - test_launcher = wlmutils.get_test_launcher() - test_interface = wlmutils.get_test_interface() - test_port = wlmutils.get_test_port() test_device = mlutils.get_test_device() test_num_gpus = 1 # TF backend fails on multiple GPUs test_script = fileutils.get_test_conf_path("run_tf_dbmodel_smartredis.py") - # Create the SmartSim Experiment - exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # Create RunSettings - run_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) + run_settings = wlm_experiment.create_run_settings( + exe=sys.executable, exe_args=test_script + ) run_settings.set_nodes(1) run_settings.set_tasks(1) # Create ensemble - smartsim_ensemble = exp.create_ensemble( + smartsim_ensemble = wlm_experiment.create_ensemble( "smartsim_model", run_settings=run_settings, replicas=2 ) # Create Model - smartsim_model = exp.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) # Create database - host = wlmutils.choose_host(run_settings) - db = exp.create_database(port=test_port, interface=test_interface, hosts=host) - exp.generate(db) + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) # Create and save ML model to filesystem model, inputs, outputs = create_tf_cnn() @@ -380,17 +356,14 @@ def test_db_model_ensemble(fileutils, test_dir, wlmutils, mlutils): # Assert we have added two models to each entity assert all([len(entity._db_models) == 2 for entity in smartsim_ensemble]) - exp.generate(smartsim_ensemble) + wlm_experiment.generate(smartsim_ensemble) # Launch and check successful completion - try: - exp.start(db, smartsim_ensemble, block=True) - statuses = exp.get_status(smartsim_ensemble) - assert all( - stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses - ), f"Statuses: {statuses}" - finally: - exp.stop(db) + wlm_experiment.start(smartsim_ensemble, block=True) + statuses = wlm_experiment.get_status(smartsim_ensemble) + assert all( + stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses + ), f"Statuses: {statuses}" @pytest.mark.skipif(not should_run_tf, reason="Test needs TF to run") diff --git a/tests/backends/test_dbscript.py b/tests/backends/test_dbscript.py index 9d0b04c8e..2c04bf5db 100644 --- a/tests/backends/test_dbscript.py +++ b/tests/backends/test_dbscript.py @@ -57,37 +57,29 @@ def timestwo(x): @pytest.mark.skipif(not should_run, reason="Test needs 
Torch to run") -def test_db_script(fileutils, test_dir, wlmutils, mlutils): +def test_db_script(wlm_experiment, prepare_db, single_db, fileutils, mlutils): """Test DB scripts on remote DB""" - # Set experiment name - exp_name = "test-db-script" - - # Retrieve parameters from testing environment - test_launcher = wlmutils.get_test_launcher() - test_interface = wlmutils.get_test_interface() - test_port = wlmutils.get_test_port() test_device = mlutils.get_test_device() test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1 test_script = fileutils.get_test_conf_path("run_dbscript_smartredis.py") torch_script = fileutils.get_test_conf_path("torchscript.py") - # Create the SmartSim Experiment - exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # Create the RunSettings - run_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) + run_settings = wlm_experiment.create_run_settings( + exe=sys.executable, exe_args=test_script + ) run_settings.set_nodes(1) run_settings.set_tasks(1) # Create the SmartSim Model - smartsim_model = exp.create_model("smartsim_model", run_settings) + smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) # Create the SmartSim database - host = wlmutils.choose_host(run_settings) - db = exp.create_database(port=test_port, interface=test_interface, hosts=host) - exp.generate(db, smartsim_model) + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + wlm_experiment.generate(smartsim_model) # Define the torch script string torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" @@ -123,51 +115,42 @@ def test_db_script(fileutils, test_dir, wlmutils, mlutils): assert len(smartsim_model._db_scripts) == 3 # Launch and check successful completion - try: - exp.start(db, smartsim_model, block=True) - statuses = exp.get_status(smartsim_model) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - finally: - exp.stop(db) + wlm_experiment.start(smartsim_model, block=True) + statuses = wlm_experiment.get_status(smartsim_model) + assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") -def test_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): +def test_db_script_ensemble(wlm_experiment, prepare_db, single_db, fileutils, mlutils): """Test DB scripts on remote DB""" - # Set experiment name - exp_name = "test-db-script" + # Set wlm_experimenteriment name + wlm_experiment_name = "test-db-script" # Retrieve parameters from testing environment - test_launcher = wlmutils.get_test_launcher() - test_interface = wlmutils.get_test_interface() - test_port = wlmutils.get_test_port() test_device = mlutils.get_test_device() test_num_gpus = mlutils.get_test_num_gpus() if pytest.test_device == "GPU" else 1 test_script = fileutils.get_test_conf_path("run_dbscript_smartredis.py") torch_script = fileutils.get_test_conf_path("torchscript.py") - # Create SmartSim Experiment - exp = Experiment(exp_name, exp_path=test_dir, launcher=test_launcher) - # Create RunSettings - run_settings = exp.create_run_settings(exe=sys.executable, exe_args=test_script) + run_settings = wlm_experiment.create_run_settings( + exe=sys.executable, exe_args=test_script + ) run_settings.set_nodes(1) run_settings.set_tasks(1) + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) + # Create Ensemble with two identical models - 
ensemble = exp.create_ensemble( + ensemble = wlm_experiment.create_ensemble( "dbscript_ensemble", run_settings=run_settings, replicas=2 ) # Create SmartSim model - smartsim_model = exp.create_model("smartsim_model", run_settings) - - # Create SmartSim database - host = wlmutils.choose_host(run_settings) - db = exp.create_database(port=test_port, interface=test_interface, hosts=host) - exp.generate(db) + smartsim_model = wlm_experiment.create_model("smartsim_model", run_settings) # Create the script string torch_script_str = "def negate(x):\n\treturn torch.neg(x)\n" @@ -217,14 +200,11 @@ def test_db_script_ensemble(fileutils, test_dir, wlmutils, mlutils): # Assert we have added all three models to entities in ensemble assert all([len(entity._db_scripts) == 3 for entity in ensemble]) - exp.generate(ensemble) + wlm_experiment.generate(ensemble) - try: - exp.start(db, ensemble, block=True) - statuses = exp.get_status(ensemble) - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - finally: - exp.stop(db) + wlm_experiment.start(ensemble, block=True) + statuses = wlm_experiment.get_status(ensemble) + assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) @pytest.mark.skipif(not should_run, reason="Test needs Torch to run") diff --git a/tests/backends/test_onnx.py b/tests/backends/test_onnx.py index 871c3f059..29771bb1c 100644 --- a/tests/backends/test_onnx.py +++ b/tests/backends/test_onnx.py @@ -57,7 +57,7 @@ ) -def test_sklearn_onnx(test_dir, mlutils, wlmutils): +def test_sklearn_onnx(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): """This test needs two free nodes, 1 for the db and 1 some sklearn models here we test the following sklearn models: @@ -74,33 +74,24 @@ def test_sklearn_onnx(test_dir, mlutils, wlmutils): You may need to put CUDNN in your LD_LIBRARY_PATH if running on GPU """ - - exp_name = "test_sklearn_onnx" - - exp = Experiment(exp_name, exp_path=test_dir, launcher=wlmutils.get_test_launcher()) test_device = mlutils.get_test_device() + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) - db = wlmutils.get_orchestrator(nodes=1) - db.set_path(test_dir) - - exp.start(db) - try: - run_settings = exp.create_run_settings( - sys.executable, f"run_sklearn_onnx.py --device={test_device}" - ) - if wlmutils.get_test_launcher() != "local": - run_settings.set_tasks(1) - model = exp.create_model("onnx_models", run_settings) + run_settings = wlm_experiment.create_run_settings( + sys.executable, f"run_sklearn_onnx.py --device={test_device}" + ) + if wlmutils.get_test_launcher() != "local": + run_settings.set_tasks(1) + model = wlm_experiment.create_model("onnx_models", run_settings) - script_dir = os.path.dirname(os.path.abspath(__file__)) - script_path = Path(script_dir, "run_sklearn_onnx.py").resolve() - model.attach_generator_files(to_copy=str(script_path)) - exp.generate(model) + script_dir = os.path.dirname(os.path.abspath(__file__)) + script_path = Path(script_dir, "run_sklearn_onnx.py").resolve() + model.attach_generator_files(to_copy=str(script_path)) + wlm_experiment.generate(model) - exp.start(model, block=True) - finally: - exp.stop(db) + wlm_experiment.start(model, block=True) # if model failed, test will fail - model_status = exp.get_status(model) + model_status = wlm_experiment.get_status(model) assert model_status[0] != SmartSimStatus.STATUS_FAILED diff --git a/tests/backends/test_tf.py b/tests/backends/test_tf.py index 92cd01695..adf0e9daa 100644 --- 
a/tests/backends/test_tf.py +++ b/tests/backends/test_tf.py @@ -50,7 +50,7 @@ (not tf_backend_available) or (not tf_available), reason="Requires RedisAI TF backend", ) -def test_keras_model(test_dir, mlutils, wlmutils): +def test_keras_model(wlm_experiment, prepare_db, single_db, mlutils, wlmutils): """This test needs two free nodes, 1 for the db and 1 for a keras model script this test can run on CPU/GPU by setting SMARTSIM_TEST_DEVICE=GPU @@ -60,33 +60,27 @@ def test_keras_model(test_dir, mlutils, wlmutils): You may need to put CUDNN in your LD_LIBRARY_PATH if running on GPU """ - exp_name = "test_keras_model" - - exp = Experiment(exp_name, exp_path=test_dir, launcher=wlmutils.get_test_launcher()) test_device = mlutils.get_test_device() + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) - db = wlmutils.get_orchestrator(nodes=1) - db.set_path(test_dir) - exp.start(db) - - run_settings = exp.create_run_settings( + run_settings = wlm_experiment.create_run_settings( "python", f"run_tf.py --device={test_device}" ) if wlmutils.get_test_launcher() != "local": run_settings.set_tasks(1) - model = exp.create_model("tf_script", run_settings) + model = wlm_experiment.create_model("tf_script", run_settings) script_dir = os.path.dirname(os.path.abspath(__file__)) script_path = Path(script_dir, "run_tf.py").resolve() model.attach_generator_files(to_copy=str(script_path)) - exp.generate(model) + wlm_experiment.generate(model) - exp.start(model, block=True) + wlm_experiment.start(model, block=True) - exp.stop(db) # if model failed, test will fail - model_status = exp.get_status(model)[0] + model_status = wlm_experiment.get_status(model)[0] assert model_status != SmartSimStatus.STATUS_FAILED diff --git a/tests/backends/test_torch.py b/tests/backends/test_torch.py index a36037de4..c995f76ca 100644 --- a/tests/backends/test_torch.py +++ b/tests/backends/test_torch.py @@ -48,7 +48,9 @@ ) -def test_torch_model_and_script(test_dir, mlutils, wlmutils): +def test_torch_model_and_script( + wlm_experiment, prepare_db, single_db, mlutils, wlmutils +): """This test needs two free nodes, 1 for the db and 1 for a torch model script Here we test both the torchscipt API and the NN API from torch @@ -60,30 +62,24 @@ def test_torch_model_and_script(test_dir, mlutils, wlmutils): You may need to put CUDNN in your LD_LIBRARY_PATH if running on GPU """ - exp_name = "test_torch_model_and_script" - - exp = Experiment(exp_name, exp_path=test_dir, launcher=wlmutils.get_test_launcher()) + db = prepare_db(single_db).orchestrator + wlm_experiment.reconnect_orchestrator(db.checkpoint_file) test_device = mlutils.get_test_device() - db = wlmutils.get_orchestrator(nodes=1) - db.set_path(test_dir) - exp.start(db) - - run_settings = exp.create_run_settings( + run_settings = wlm_experiment.create_run_settings( "python", f"run_torch.py --device={test_device}" ) if wlmutils.get_test_launcher() != "local": run_settings.set_tasks(1) - model = exp.create_model("torch_script", run_settings) + model = wlm_experiment.create_model("torch_script", run_settings) script_dir = os.path.dirname(os.path.abspath(__file__)) script_path = Path(script_dir, "run_torch.py").resolve() model.attach_generator_files(to_copy=str(script_path)) - exp.generate(model) + wlm_experiment.generate(model) - exp.start(model, block=True) + wlm_experiment.start(model, block=True) - exp.stop(db) # if model failed, test will fail - model_status = exp.get_status(model)[0] + model_status = wlm_experiment.get_status(model)[0] assert 
model_status != SmartSimStatus.STATUS_FAILED diff --git a/tests/on_wlm/test_symlinking.py b/tests/on_wlm/test_symlinking.py index 246457d1c..c5b5b90ba 100644 --- a/tests/on_wlm/test_symlinking.py +++ b/tests/on_wlm/test_symlinking.py @@ -28,8 +28,13 @@ import pathlib import time +import pytest + from smartsim import Experiment +if pytest.test_launcher not in pytest.wlm_options: + pytestmark = pytest.mark.skip(reason="Not testing WLM integrations") + def test_batch_model_and_ensemble(test_dir, wlmutils): exp_name = "test-batch" diff --git a/tests/on_wlm/test_wlm_orc_config_settings.py b/tests/on_wlm/test_wlm_orc_config_settings.py index 7727b0a46..c74f2a497 100644 --- a/tests/on_wlm/test_wlm_orc_config_settings.py +++ b/tests/on_wlm/test_wlm_orc_config_settings.py @@ -43,9 +43,10 @@ pytestmark = pytest.mark.skip(reason="SmartRedis version is < 0.3.1") -def test_config_methods_on_wlm_single(dbutils, db): +def test_config_methods_on_wlm_single(dbutils, prepare_db, single_db): """Test all configuration file edit methods on single node WLM db""" + db = prepare_db(single_db).orchestrator # test the happy path and ensure all configuration file edit methods # successfully execute when given correct key-value pairs configs = dbutils.get_db_configs() @@ -71,15 +72,16 @@ def test_config_methods_on_wlm_single(dbutils, db): db.set_db_conf(key, value) -def test_config_methods_on_wlm_cluster(dbutils, db_cluster): +def test_config_methods_on_wlm_cluster(dbutils, prepare_db, clustered_db): """Test all configuration file edit methods on an active clustered db""" + db = prepare_db(clustered_db).orchestrator # test the happy path and ensure all configuration file edit methods # successfully execute when given correct key-value pairs configs = dbutils.get_db_configs() for setting, value in configs.items(): logger.debug(f"Setting {setting}={value}") - config_set_method = dbutils.get_config_edit_method(db_cluster, setting) + config_set_method = dbutils.get_config_edit_method(db, setting) config_set_method(value) # ensure SmartSimError is raised when a clustered database's @@ -89,7 +91,7 @@ def test_config_methods_on_wlm_cluster(dbutils, db_cluster): for value in value_list: with pytest.raises(SmartSimError): logger.debug(f"Setting {key}={value}") - db_cluster.set_db_conf(key, value) + db.set_db_conf(key, value) # ensure TypeError is raised when a clustered database's # Orchestrator.set_db_conf is given invalid CONFIG key-value pairs @@ -98,4 +100,4 @@ def test_config_methods_on_wlm_cluster(dbutils, db_cluster): for value in value_list: with pytest.raises(TypeError): logger.debug(f"Setting {key}={value}") - db_cluster.set_db_conf(key, value) + db.set_db_conf(key, value) diff --git a/tests/test_collector_manager.py b/tests/test_collector_manager.py index 91f083487..56add1ef7 100644 --- a/tests/test_collector_manager.py +++ b/tests/test_collector_manager.py @@ -246,11 +246,13 @@ async def test_collector_manager_collect_filesink( @pytest.mark.asyncio async def test_collector_manager_collect_integration( - test_dir: str, mock_entity: MockCollectorEntityFunc, local_db, mock_sink + test_dir: str, mock_entity: MockCollectorEntityFunc, prepare_db, local_db, mock_sink ) -> None: """Ensure that all collectors are executed and some metric is retrieved""" - entity1 = mock_entity(port=local_db.ports[0], name="e1", telemetry_on=True) - entity2 = mock_entity(port=local_db.ports[0], name="e2", telemetry_on=True) + + db = prepare_db(local_db).orchestrator + entity1 = mock_entity(port=db.ports[0], name="e1", telemetry_on=True) + 
entity2 = mock_entity(port=db.ports[0], name="e2", telemetry_on=True) # todo: consider a MockSink so i don't have to save the last value in the collector sinks = [mock_sink(), mock_sink(), mock_sink()] diff --git a/tests/test_collectors.py b/tests/test_collectors.py index f56c89736..2eb61d62d 100644 --- a/tests/test_collectors.py +++ b/tests/test_collectors.py @@ -42,6 +42,8 @@ # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a +PrepareDB = t.Callable[[dict], smartsim.experiment.Orchestrator] + @pytest.mark.asyncio async def test_dbmemcollector_prepare( @@ -171,12 +173,15 @@ async def test_dbmemcollector_collect( async def test_dbmemcollector_integration( mock_entity: MockCollectorEntityFunc, mock_sink: MockSink, - local_db: smartsim.experiment.Orchestrator, + prepare_db: PrepareDB, + local_db: dict, monkeypatch: pytest.MonkeyPatch, ) -> None: """Integration test with a real orchestrator instance to ensure output data matches expectations and proper db client API uage""" - entity = mock_entity(port=local_db.ports[0], telemetry_on=True) + + db = prepare_db(local_db).orchestrator + entity = mock_entity(port=db.ports[0], telemetry_on=True) sink = mock_sink() collector = DBMemoryCollector(entity, sink) @@ -268,12 +273,15 @@ async def test_dbconn_count_collector_collect( async def test_dbconncollector_integration( mock_entity: MockCollectorEntityFunc, mock_sink: MockSink, - local_db: smartsim.experiment.Orchestrator, + prepare_db: PrepareDB, + local_db: dict, monkeypatch: pytest.MonkeyPatch, ) -> None: """Integration test with a real orchestrator instance to ensure output data matches expectations and proper db client API uage""" - entity = mock_entity(port=local_db.ports[0], telemetry_on=True) + + db = prepare_db(local_db).orchestrator + entity = mock_entity(port=db.ports[0], telemetry_on=True) sink = mock_sink() collector = DBConnectionCollector(entity, sink) diff --git a/tests/test_configs/smartredis/multidbid_colo_env_vars_only.py b/tests/test_configs/smartredis/multidbid_colo_env_vars_only.py new file mode 100644 index 000000000..74a15c010 --- /dev/null +++ b/tests/test_configs/smartredis/multidbid_colo_env_vars_only.py @@ -0,0 +1,52 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2023, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import os + +from smartredis import Client, ConfigOptions + +if __name__ == "__main__": + """For inclusion in test with two unique database identifiers with multiple + databases where one (presumably colocated) database is started before the + other, and thus only one DB ID is known at application runtime and + available via environment variable. + """ + + parser = argparse.ArgumentParser(description="SmartRedis") + parser.add_argument("--exchange", action="store_true") + parser.add_argument("--should-see-reg-db", action="store_true") + args = parser.parse_args() + + env_vars = [ + "SSDB_testdb_colo", + "SR_DB_TYPE_testdb_colo", + ] + + assert all([var in os.environ for var in env_vars]) + + opts = ConfigOptions.create_from_environment("testdb_colo") + Client(opts, logger_name="SmartSim") diff --git a/tests/test_containers.py b/tests/test_containers.py index 98fa5e1bb..5d0f933ff 100644 --- a/tests/test_containers.py +++ b/tests/test_containers.py @@ -31,8 +31,7 @@ import pytest -from smartsim import Experiment -from smartsim.database import Orchestrator +from smartsim import Experiment, status from smartsim.entity import Ensemble from smartsim.settings.containers import Singularity from smartsim.status import SmartSimStatus @@ -143,7 +142,7 @@ def test_singularity_args(fileutils, test_dir): @pytest.mark.skipif(not singularity_exists, reason="Test needs singularity to run") -def test_singularity_smartredis(test_dir, fileutils, wlmutils): +def test_singularity_smartredis(local_experiment, prepare_db, local_db, fileutils): """Run two processes, each process puts a tensor on the DB, then accesses the other process's tensor. Finally, the tensor is used to run a model. 
@@ -151,18 +150,13 @@ def test_singularity_smartredis(test_dir, fileutils, wlmutils): Note: This is a containerized port of test_smartredis.py """ - exp = Experiment( - "smartredis_ensemble_exchange", exp_path=test_dir, launcher="local" - ) - # create and start a database - orc = Orchestrator(port=wlmutils.get_test_port()) - exp.generate(orc) - exp.start(orc, block=False) + db = prepare_db(local_db).orchestrator + local_experiment.reconnect_orchestrator(db.checkpoint_file) container = Singularity(containerURI) - rs = exp.create_run_settings( + rs = local_experiment.create_run_settings( "python3", "producer.py --exchange", container=container ) params = {"mult": [1, -10]} @@ -179,18 +173,12 @@ def test_singularity_smartredis(test_dir, fileutils, wlmutils): config = fileutils.get_test_conf_path("smartredis") ensemble.attach_generator_files(to_copy=[config]) - exp.generate(ensemble) + local_experiment.generate(ensemble) # start the models - exp.start(ensemble, summary=False) + local_experiment.start(ensemble, summary=False) # get and confirm statuses - statuses = exp.get_status(ensemble) + statuses = local_experiment.get_status(ensemble) if not all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]): - exp.stop(orc) assert False # client ensemble failed - - # stop the orchestrator - exp.stop(orc) - - print(exp.summary()) diff --git a/tests/test_dbnode.py b/tests/test_dbnode.py index 227572ac9..04845344c 100644 --- a/tests/test_dbnode.py +++ b/tests/test_dbnode.py @@ -49,22 +49,12 @@ def test_parse_db_host_error(): orc.entities[0].host -def test_hosts(test_dir, wlmutils): - exp_name = "test_hosts" - exp = Experiment(exp_name, exp_path=test_dir) - - orc = Orchestrator(port=wlmutils.get_test_port(), interface="lo", launcher="local") - orc.set_path(test_dir) - exp.start(orc) - - hosts = [] - try: - hosts = orc.hosts - assert len(hosts) == orc.db_nodes == 1 - finally: - # stop the database even if there is an error raised - exp.stop(orc) - orc.remove_stale_files() +def test_hosts(local_experiment, prepare_db, local_db): + db = prepare_db(local_db).orchestrator + orc = local_experiment.reconnect_orchestrator(db.checkpoint_file) + + hosts = orc.hosts + assert len(hosts) == orc.db_nodes == 1 def _random_shard_info(): diff --git a/tests/test_fixtures.py b/tests/test_fixtures.py new file mode 100644 index 000000000..ea753374e --- /dev/null +++ b/tests/test_fixtures.py @@ -0,0 +1,56 @@ +# BSD 2-Clause License +# +# Copyright (c) 2021-2024, Hewlett Packard Enterprise +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import os + +import psutil +import pytest + +from smartsim import Experiment +from smartsim.database import Orchestrator +from smartsim.error import SmartSimError +from smartsim.error.errors import SSUnsupportedError + +# The tests in this file belong to the group_a group +pytestmark = pytest.mark.group_a + + +def test_db_fixtures(local_experiment, local_db, prepare_db): + db = prepare_db(local_db).orchestrator + local_experiment.reconnect_orchestrator(db.checkpoint_file) + assert db.is_active() + local_experiment.stop(db) + + +def test_create_new_db_fixture_if_stopped(local_experiment, local_db, prepare_db): + # Run this twice to make sure that there is a stopped database + output = prepare_db(local_db) + local_experiment.reconnect_orchestrator(output.orchestrator.checkpoint_file) + local_experiment.stop(output.orchestrator) + + output = prepare_db(local_db) + assert output.new_db + local_experiment.reconnect_orchestrator(output.orchestrator.checkpoint_file) + assert output.orchestrator.is_active() diff --git a/tests/test_multidb.py b/tests/test_multidb.py index 94ac13198..81f21856a 100644 --- a/tests/test_multidb.py +++ b/tests/test_multidb.py @@ -403,7 +403,9 @@ def test_multidb_colo_then_standard(fileutils, test_dir, wlmutils, coloutils, db # Retrieve parameters from testing environment test_port = wlmutils.get_test_port() - test_script = fileutils.get_test_conf_path("smartredis/multidbid.py") + test_script = fileutils.get_test_conf_path( + "smartredis/multidbid_colo_env_vars_only.py" + ) test_interface = wlmutils.get_test_interface() test_launcher = wlmutils.get_test_launcher() @@ -433,8 +435,9 @@ def test_multidb_colo_then_standard(fileutils, test_dir, wlmutils, coloutils, db ) with make_entity_context(exp, db), make_entity_context(exp, smartsim_model): + exp.start(smartsim_model, block=False) exp.start(db) - exp.start(smartsim_model, block=True) + exp.poll(smartsim_model) check_not_failed(exp, db, smartsim_model) diff --git a/tests/test_orc_config_settings.py b/tests/test_orc_config_settings.py index 365596496..74d0c1af2 100644 --- a/tests/test_orc_config_settings.py +++ b/tests/test_orc_config_settings.py @@ -27,6 +27,7 @@ import pytest +from smartsim.database import Orchestrator from smartsim.error import SmartSimError try: @@ -40,14 +41,15 @@ pytestmark = pytest.mark.group_b -def test_config_methods(dbutils, local_db): +def test_config_methods(dbutils, prepare_db, local_db): """Test all configuration file edit methods on an active db""" + db = prepare_db(local_db).orchestrator # test the happy path and ensure all configuration file edit methods # successfully execute when given correct key-value pairs configs = dbutils.get_db_configs() for setting, value in configs.items(): - config_set_method = dbutils.get_config_edit_method(local_db, setting) + config_set_method = dbutils.get_config_edit_method(db, setting) config_set_method(value) # ensure SmartSimError is raised when Orchestrator.set_db_conf @@ -56,7 +58,7 @@ def test_config_methods(dbutils, local_db): 
for key, value_list in ss_error_configs.items(): for value in value_list: with pytest.raises(SmartSimError): - local_db.set_db_conf(key, value) + db.set_db_conf(key, value) # ensure TypeError is raised when Orchestrator.set_db_conf # is given either a key or a value that is not a string @@ -64,14 +66,14 @@ def test_config_methods(dbutils, local_db): for key, value_list in type_error_configs.items(): for value in value_list: with pytest.raises(TypeError): - local_db.set_db_conf(key, value) + db.set_db_conf(key, value) -def test_config_methods_inactive(wlmutils, dbutils): +def test_config_methods_inactive(dbutils): """Ensure a SmartSimError is raised when trying to set configurations on an inactive database """ - db = wlmutils.get_orchestrator() + db = Orchestrator() configs = dbutils.get_db_configs() for setting, value in configs.items(): config_set_method = dbutils.get_config_edit_method(db, setting) diff --git a/tests/test_orchestrator.py b/tests/test_orchestrator.py index 54f86ad99..66fb894f7 100644 --- a/tests/test_orchestrator.py +++ b/tests/test_orchestrator.py @@ -74,34 +74,22 @@ def test_inactive_orc_get_address() -> None: db.get_address() -def test_orc_active_functions(test_dir: str, wlmutils: "conftest.WLMUtils") -> None: - exp_name = "test_orc_active_functions" - exp = Experiment(exp_name, launcher="local", exp_path=test_dir) - - db = Orchestrator(port=wlmutils.get_test_port()) - db.set_path(test_dir) - - exp.start(db) - - # check if the orchestrator is active +def test_orc_is_active_functions( + local_experiment, + prepare_db, + local_db, +) -> None: + db = prepare_db(local_db).orchestrator + db = local_experiment.reconnect_orchestrator(db.checkpoint_file) assert db.is_active() # check if the orchestrator can get the address - correct_address = db.get_address() == ["127.0.0.1:" + str(wlmutils.get_test_port())] - if not correct_address: - exp.stop(db) - assert False + assert db.get_address() == [f"127.0.0.1:{db.ports[0]}"] - exp.stop(db) - assert not db.is_active() - - # check if orchestrator.get_address() raises an exception - with pytest.raises(SmartSimError): - db.get_address() - - -def test_multiple_interfaces(test_dir: str, wlmutils: "conftest.WLMUtils") -> None: +def test_multiple_interfaces( + test_dir: str, wlmutils: t.Type["conftest.WLMUtils"] +) -> None: exp_name = "test_multiple_interfaces" exp = Experiment(exp_name, launcher="local", exp_path=test_dir) @@ -112,7 +100,8 @@ def test_multiple_interfaces(test_dir: str, wlmutils: "conftest.WLMUtils") -> No net_if_addrs = ["lo", net_if_addrs[0]] - db = Orchestrator(port=wlmutils.get_test_port(), interface=net_if_addrs) + port = wlmutils.get_test_port() + db = Orchestrator(port=port, interface=net_if_addrs) db.set_path(test_dir) exp.start(db) @@ -121,8 +110,9 @@ def test_multiple_interfaces(test_dir: str, wlmutils: "conftest.WLMUtils") -> No assert db.is_active() # check if the orchestrator can get the address - correct_address = db.get_address() == ["127.0.0.1:" + str(wlmutils.get_test_port())] - if not correct_address: + correct_address = [f"127.0.0.1:{port}"] + + if not correct_address == db.get_address(): exp.stop(db) assert False @@ -146,7 +136,7 @@ def test_catch_local_db_errors() -> None: ##### PBS ###### -def test_pbs_set_run_arg(wlmutils: "conftest.WLMUtils") -> None: +def test_pbs_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -165,7 +155,7 @@ def test_pbs_set_run_arg(wlmutils: "conftest.WLMUtils") -> None: ) -def 
test_pbs_set_batch_arg(wlmutils: "conftest.WLMUtils") -> None: +def test_pbs_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -194,7 +184,7 @@ def test_pbs_set_batch_arg(wlmutils: "conftest.WLMUtils") -> None: ##### Slurm ###### -def test_slurm_set_run_arg(wlmutils: "conftest.WLMUtils") -> None: +def test_slurm_set_run_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -209,7 +199,7 @@ def test_slurm_set_run_arg(wlmutils: "conftest.WLMUtils") -> None: ) -def test_slurm_set_batch_arg(wlmutils: "conftest.WLMUtils") -> None: +def test_slurm_set_batch_arg(wlmutils: t.Type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -265,7 +255,7 @@ def test_orc_results_in_correct_number_of_shards(single_cmd: bool) -> None: ###### LSF ###### -def test_catch_orc_errors_lsf(wlmutils: "conftest.WLMUtils") -> None: +def test_catch_orc_errors_lsf(wlmutils: t.Type["conftest.WLMUtils"]) -> None: with pytest.raises(SSUnsupportedError): orc = Orchestrator( wlmutils.get_test_port(), @@ -288,7 +278,7 @@ def test_catch_orc_errors_lsf(wlmutils: "conftest.WLMUtils") -> None: orc.set_batch_arg("P", "MYPROJECT") -def test_lsf_set_run_args(wlmutils: "conftest.WLMUtils") -> None: +def test_lsf_set_run_args(wlmutils: t.Type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -301,7 +291,7 @@ def test_lsf_set_run_args(wlmutils: "conftest.WLMUtils") -> None: assert all(["l" not in db.run_settings.run_args for db in orc.entities]) -def test_lsf_set_batch_args(wlmutils: "conftest.WLMUtils") -> None: +def test_lsf_set_batch_args(wlmutils: t.Type["conftest.WLMUtils"]) -> None: orc = Orchestrator( wlmutils.get_test_port(), db_nodes=3, @@ -316,7 +306,7 @@ def test_lsf_set_batch_args(wlmutils: "conftest.WLMUtils") -> None: assert orc.batch_settings.batch_args["D"] == "102400000" -def test_orc_telemetry(test_dir: str, wlmutils: "conftest.WLMUtils") -> None: +def test_orc_telemetry(test_dir: str, wlmutils: t.Type["conftest.WLMUtils"]) -> None: """Ensure the default behavior for an orchestrator is to disable telemetry""" db = Orchestrator(port=wlmutils.get_test_port()) db.set_path(test_dir) diff --git a/tests/test_smartredis.py b/tests/test_smartredis.py index a2aac654b..6f7b19934 100644 --- a/tests/test_smartredis.py +++ b/tests/test_smartredis.py @@ -60,22 +60,17 @@ ) -def test_exchange(fileutils, test_dir, wlmutils): +def test_exchange(local_experiment, local_db, prepare_db, fileutils): """Run two processes, each process puts a tensor on the DB, then accesses the other process's tensor. Finally, the tensor is used to run a model. 
""" - exp = Experiment( - "smartredis_ensemble_exchange", exp_path=test_dir, launcher="local" - ) - + db = prepare_db(local_db).orchestrator # create and start a database - orc = Orchestrator(port=wlmutils.get_test_port()) - exp.generate(orc) - exp.start(orc, block=False) + local_experiment.reconnect_orchestrator(db.checkpoint_file) - rs = exp.create_run_settings("python", "producer.py --exchange") + rs = local_experiment.create_run_settings("python", "producer.py --exchange") params = {"mult": [1, -10]} ensemble = Ensemble( name="producer", @@ -90,21 +85,17 @@ def test_exchange(fileutils, test_dir, wlmutils): config = fileutils.get_test_conf_path("smartredis") ensemble.attach_generator_files(to_copy=[config]) - exp.generate(ensemble) + local_experiment.generate(ensemble) # start the models - exp.start(ensemble, summary=False) + local_experiment.start(ensemble, summary=False) # get and confirm statuses - statuses = exp.get_status(ensemble) - try: - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - finally: - # stop the orchestrator - exp.stop(orc) + statuses = local_experiment.get_status(ensemble) + assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) -def test_consumer(fileutils, test_dir, wlmutils): +def test_consumer(local_experiment, local_db, prepare_db, fileutils): """Run three processes, each one of the first two processes puts a tensor on the DB; the third process accesses the tensors put by the two producers. @@ -112,17 +103,11 @@ def test_consumer(fileutils, test_dir, wlmutils): and the consumer accesses the two results. """ - exp = Experiment( - "smartredis_ensemble_consumer", exp_path=test_dir, launcher="local" - ) - - # create and start a database - orc = Orchestrator(port=wlmutils.get_test_port()) - exp.generate(orc) - exp.start(orc, block=False) + db = prepare_db(local_db).orchestrator + local_experiment.reconnect_orchestrator(db.checkpoint_file) - rs_prod = exp.create_run_settings("python", "producer.py") - rs_consumer = exp.create_run_settings("python", "consumer.py") + rs_prod = local_experiment.create_run_settings("python", "producer.py") + rs_consumer = local_experiment.create_run_settings("python", "consumer.py") params = {"mult": [1, -10]} ensemble = Ensemble( name="producer", params=params, run_settings=rs_prod, perm_strat="step" @@ -139,15 +124,11 @@ def test_consumer(fileutils, test_dir, wlmutils): config = fileutils.get_test_conf_path("smartredis") ensemble.attach_generator_files(to_copy=[config]) - exp.generate(ensemble) + local_experiment.generate(ensemble) # start the models - exp.start(ensemble, summary=False) + local_experiment.start(ensemble, summary=False) # get and confirm statuses - statuses = exp.get_status(ensemble) - try: - assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses]) - finally: - # stop the orchestrator - exp.stop(orc) + statuses = local_experiment.get_status(ensemble) + assert all([stat == SmartSimStatus.STATUS_COMPLETED for stat in statuses])