diff --git a/smartsim/entity/dbobject.py b/smartsim/entity/dbobject.py index be57d8479..9e1dcc348 100644 --- a/smartsim/entity/dbobject.py +++ b/smartsim/entity/dbobject.py @@ -132,21 +132,24 @@ def _enumerate_devices(self) -> t.List[str]: def _check_devices( device: t.Literal["CPU", "GPU"], devices_per_node: int, first_device: int, ) -> None: + if device == "CPU" and devices_per_node > 1: + raise SSUnsupportedError( + "Cannot set devices_per_node>1 if CPU is specified under devices" + ) + + if device == "CPU" and first_device > 0: + raise SSUnsupportedError( + "Cannot set first_device>0 if CPU is specified under devices" + ) + if devices_per_node == 1: return - if first_device < 0: - raise ValueError("Cannot set first_device to a negative number") - if ":" in device: msg = "Cannot set devices_per_node>1 if a device numeral is specified, " msg += f"the device was set to {device} and \ devices_per_node=={devices_per_node}" raise ValueError(msg) - if device == "CPU": - raise SSUnsupportedError( - "Cannot set devices_per_node>1 if CPU is specified under devices" - ) class DBScript(DBObject): diff --git a/tests/backends/test_dbscript.py b/tests/backends/test_dbscript.py index b9aaff64d..c92be31de 100644 --- a/tests/backends/test_dbscript.py +++ b/tests/backends/test_dbscript.py @@ -603,7 +603,7 @@ def test_inconsistent_params_db_script(fileutils): torch_script = fileutils.get_test_conf_path("torchscript.py") with pytest.raises(SSUnsupportedError) as ex: - db_script = DBScript( + _ = DBScript( name="test_script_db", script_path=torch_script, device="CPU", @@ -614,3 +614,15 @@ def test_inconsistent_params_db_script(fileutils): ex.value.args[0] == "Cannot set devices_per_node>1 if CPU is specified under devices" ) + with pytest.raises(SSUnsupportedError) as ex: + _ = DBScript( + name="test_script_db", + script_path=torch_script, + device="CPU", + devices_per_node=1, + first_device=5, + ) + assert ( + ex.value.args[0] + == "Cannot set first_device>0 if CPU is specified under devices" + )