diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 8dc5113d5..aff32604c 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -42,6 +42,7 @@ from smartsim._core.control.launch_history import LaunchHistory from smartsim._core.utils.launcher import LauncherProtocol, create_job_id from smartsim.entity import entity +from smartsim.error import errors from smartsim.experiment import Experiment from smartsim.launchable import job from smartsim.settings import launchSettings @@ -318,10 +319,16 @@ def create(cls, _): def start(self, _): raise NotImplementedError("{type(self).__name__} should not start anything") + def _assert_ids(self, ids: LaunchedJobID): + if any(id_ not in self.id_to_status for id_ in ids): + raise errors.LauncherJobNotFound + def get_status(self, *ids: LaunchedJobID): + self._assert_ids(ids) return {id_: self.id_to_status[id_] for id_ in ids} def stop_jobs(self, *ids: LaunchedJobID): + self._assert_ids(ids) stopped = {id_: JobStatus.CANCELLED for id_ in ids} self.id_to_status |= stopped return stopped