Skip to content

Commit

Permalink
Raising error for inconsistent add_ml_model and add_script parameters (
Browse files Browse the repository at this point in the history
…#324)

Added error when user requests CPU with devices >1 within add_ml_model and add_script methods. Included more strict typing and tests.  

[ committed by @juliaputko ]
[ reviewed by @ashao, @ankona ]
  • Loading branch information
juliaputko authored Jul 29, 2023
1 parent 9d7ac35 commit 53bff05
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 30 deletions.
7 changes: 2 additions & 5 deletions smartsim/_core/utils/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,8 @@ def db_is_active(hosts: t.List[str], ports: t.List[int], num_shards: int) -> boo

def set_ml_model(db_model: DBModel, client: Client) -> None:
logger.debug(f"Adding DBModel named {db_model.name}")
devices = db_model._enumerate_devices() # pylint: disable=protected-access

for device in devices:
for device in db_model.devices:
try:
if db_model.is_file:
client.set_model_from_file(
Expand Down Expand Up @@ -194,9 +193,7 @@ def set_ml_model(db_model: DBModel, client: Client) -> None:
def set_script(db_script: DBScript, client: Client) -> None:
logger.debug(f"Adding DBScript named {db_script.name}")

devices = db_script._enumerate_devices() # pylint: disable=protected-access

for device in devices:
for device in db_script.devices:
try:
if db_script.is_file:
client.set_script_from_file(
Expand Down
54 changes: 35 additions & 19 deletions smartsim/entity/dbobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

from pathlib import Path
from .._core.utils import init_default
from ..error import SSUnsupportedError


__all__ = ["DBObject", "DBModel", "DBScript"]
Expand All @@ -43,7 +44,7 @@ def __init__(
name: str,
func: t.Optional[str],
file_path: t.Optional[str],
device: str,
device: t.Literal["CPU", "GPU"],
devices_per_node: int,
) -> None:
self.name = name
Expand All @@ -55,6 +56,11 @@ def __init__(
self.file = self._check_filepath(file_path)
self.device = self._check_device(device)
self.devices_per_node = devices_per_node
self._check_devices(device, devices_per_node)

@property
def devices(self) -> t.List[str]:
return self._enumerate_devices()

@property
def is_file(self) -> bool:
Expand Down Expand Up @@ -95,8 +101,8 @@ def _check_filepath(file: str) -> Path:
return file_path

@staticmethod
def _check_device(device: str) -> str:
device = device.upper()
def _check_device(device: t.Literal["CPU", "GPU"]) -> str:
device = t.cast(t.Literal["CPU", "GPU"], device.upper())
if not device.startswith("CPU") and not device.startswith("GPU"):
raise ValueError("Device argument must start with either CPU or GPU")
return device
Expand All @@ -109,21 +115,31 @@ def _enumerate_devices(self) -> t.List[str]:
:return: list of device names
:rtype: list[str]
"""
devices = []
if ":" in self.device and self.devices_per_node > 1:
msg = (
"Cannot set devices_per_node>1 if a device numeral is specified, "
f"the device was set to {self.device} and "
f"devices_per_node=={self.devices_per_node}"
)
raise ValueError(msg)
if self.device in ["CPU", "GPU"] and self.devices_per_node > 1:
for device_num in range(self.devices_per_node):
devices.append(f"{self.device}:{str(device_num)}")
else:
devices = [self.device]

return devices
if self.device == "GPU" and self.devices_per_node > 1:
return [
f"{self.device}:{str(device_num)}"
for device_num in range(self.devices_per_node)
]

return [self.device]

@staticmethod
def _check_devices(
device: t.Literal["CPU", "GPU"], devices_per_node: int
) -> None:
if devices_per_node == 1:
return

if ":" in device:
msg = "Cannot set devices_per_node>1 if a device numeral is specified, "
msg += f"the device was set to {device} and \
devices_per_node=={devices_per_node}"
raise ValueError(msg)
if device == "CPU":
raise SSUnsupportedError(
"Cannot set devices_per_node>1 if CPU is specified under devices"
)


class DBScript(DBObject):
Expand All @@ -132,7 +148,7 @@ def __init__(
name: str,
script: t.Optional[str] = None,
script_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU", "GPU"] = "CPU",
devices_per_node: int = 1,
):
"""TorchScript code represenation
Expand Down Expand Up @@ -185,7 +201,7 @@ def __init__(
backend: str,
model: t.Optional[str] = None,
model_file: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU", "GPU"] = "CPU",
devices_per_node: int = 1,
batch_size: int = 0,
min_batch_size: int = 0,
Expand Down
6 changes: 3 additions & 3 deletions smartsim/entity/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def add_ml_model(
backend: str,
model: t.Optional[str] = None,
model_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
batch_size: int = 0,
min_batch_size: int = 0,
Expand Down Expand Up @@ -395,7 +395,7 @@ def add_script(
name: str,
script: t.Optional[str] = None,
script_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript to launch with every entity belonging to this ensemble
Expand Down Expand Up @@ -439,7 +439,7 @@ def add_function(
self,
name: str,
function: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript function to launch with every entity belonging to this ensemble
Expand Down
6 changes: 3 additions & 3 deletions smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def add_ml_model(
backend: str,
model: t.Optional[str] = None,
model_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
batch_size: int = 0,
min_batch_size: int = 0,
Expand Down Expand Up @@ -467,7 +467,7 @@ def add_script(
name: str,
script: t.Optional[str] = None,
script_path: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript to launch with this Model instance
Expand Down Expand Up @@ -511,7 +511,7 @@ def add_function(
self,
name: str,
function: t.Optional[str] = None,
device: str = "CPU",
device: t.Literal["CPU","GPU"] = "CPU",
devices_per_node: int = 1,
) -> None:
"""TorchScript function to launch with this Model instance
Expand Down
23 changes: 23 additions & 0 deletions tests/backends/test_dbmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
from smartsim.error.errors import SSUnsupportedError
from smartsim.log import get_logger

from smartsim.entity.dbobject import DBModel

logger = get_logger(__name__)

should_run_tf = True
Expand Down Expand Up @@ -793,3 +795,24 @@ def test_colocated_db_model_errors(fileutils, wlmutils, mlutils):

with pytest.raises(SSUnsupportedError):
colo_ensemble.add_model(colo_model)

def test_inconsistent_params_db_model():
"""Test error when devices_per_node parameter>1 when devices is set to CPU in DBModel"""

# Create and save ML model to filesystem
model, inputs, outputs = create_tf_cnn()
with pytest.raises(SSUnsupportedError) as ex:
db_model = DBModel(
"cnn",
"TF",
model=model,
device="CPU",
devices_per_node=2,
tag="test",
inputs=inputs,
outputs=outputs,
)
assert (
ex.value.args[0]
== "Cannot set devices_per_node>1 if CPU is specified under devices"
)
18 changes: 18 additions & 0 deletions tests/backends/test_dbscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from smartsim.error.errors import SSUnsupportedError
from smartsim.log import get_logger

from smartsim.entity.dbobject import DBScript

logger = get_logger(__name__)

should_run = True
Expand Down Expand Up @@ -578,3 +580,19 @@ def test_db_script_errors(fileutils, wlmutils, mlutils):
# an in-memory script
with pytest.raises(SSUnsupportedError):
colo_ensemble.add_model(colo_model)

def test_inconsistent_params_db_script(fileutils):
"""Test error when devices_per_node>1 and when devices is set to CPU in DBScript constructor"""

torch_script = fileutils.get_test_conf_path("torchscript.py")
with pytest.raises(SSUnsupportedError) as ex:
db_script = DBScript(
name="test_script_db",
script_path = torch_script,
device="CPU",
devices_per_node=2,
)
assert (
ex.value.args[0]
== "Cannot set devices_per_node>1 if CPU is specified under devices"
)

0 comments on commit 53bff05

Please sign in to comment.