Raise error for inconsistent add_ml_model and add_script parameters #324
…_script bad params
smartsim/entity/dbobject.py
Outdated
        for device_num in range(self.devices_per_node):
            devices.append(f"{self.device}:{str(device_num)}")
    else:
        devices = [self.device]

    return devices

def _check_arguments(self):
I love pulling complex validation out into separate methods, but _check_arguments feels like it could be unclear which arguments are being checked (especially since it has none).
I'd probably rename this to _check_devices to make it clear where it's valid to use. I'd also probably pass the inputs instead of allowing the object to maybe be partially initialized.
I could see this as being a general sanity checker of the initialization. There's a bit of a tradeoff between cluttering up the constructor and breaking it out into a validator for just devices and number of devices.
As currently coded up (where it only relies on the set properties after initialization is completed), it's unambiguous that the settings are correct. Thoughts @juliaputko @ankona?
It looks like _check_xxx methods in the constructor follow the pattern of offloading per-field validation to another method. I wouldn't care if they're combined, but _check_parameters is definitely improper naming given the existing:
self.file = self._check_filepath(file_path)
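The suggestions in this thread (a per-field name such as _check_devices, and passing the inputs rather than reading from a possibly half-built object) could be sketched roughly as follows. DeviceSpec is a hypothetical stand-in and not SmartSim's DBObject; the body of _check_device is an assumption about what that helper does, and ValueError is used for every failure to keep the sketch dependency-free.

```python
class DeviceSpec:
    """Hypothetical stand-in for the class under review (not SmartSim's DBObject)."""

    def __init__(self, device: str, devices_per_node: int = 1) -> None:
        # Each helper receives its inputs explicitly and returns the validated
        # value, so no helper depends on which attributes are already set.
        self.device = self._check_device(device)
        self.devices_per_node = self._check_devices(self.device, devices_per_node)

    @staticmethod
    def _check_device(device: str) -> str:
        # Assumed behavior: normalize case and require a CPU/GPU prefix.
        normalized = device.upper()
        if not normalized.startswith(("CPU", "GPU")):
            raise ValueError(f"Device must start with CPU or GPU, got {device!r}")
        return normalized

    @staticmethod
    def _check_devices(device: str, devices_per_node: int) -> int:
        if devices_per_node > 1 and ":" in device:
            raise ValueError(
                "Cannot set devices_per_node>1 if a device numeral is specified, "
                f"the device was set to {device} and devices_per_node=={devices_per_node}"
            )
        if devices_per_node > 1 and device == "CPU":
            raise ValueError(
                "Cannot set devices_per_node>1 if CPU is specified under devices"
            )
        return devices_per_node
```

Because both helpers are static and take arguments, they can also be called directly in tests without constructing the object.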
smartsim/entity/dbobject.py
Outdated
    msg = "Cannot set devices_per_node>1 if a device numeral is specified, "
    msg += f"the device was set to {self.device} and devices_per_node=={self.devices_per_node}"
    raise ValueError(msg)
if self.device in ["CPU"] and self.devices_per_node > 1:
No need for in here? A plain self.device == "CPU" comparison would do.
smartsim/entity/dbobject.py
Outdated
        for device_num in range(self.devices_per_node):
            devices.append(f"{self.device}:{str(device_num)}")
    else:
        devices = [self.device]

    return devices

def _check_arguments(self):
    devices = []
    if ":" in self.device and self.devices_per_node > 1:
If we do a check at the start:
if self.devices_per_node <= 1:
    return
we can avoid having compound conditionals below, making it easier to read what's required to trigger a failure.
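A minimal sketch of that guard clause, written as a free function so it runs on its own. The function name and signature are hypothetical, the messages are copied from the diff, and ValueError stands in for SSUnsupportedError in the CPU case.

```python
def check_devices(device: str, devices_per_node: int) -> None:
    """Hypothetical free-function version of the validator under review."""
    # With at most one device per node nothing below can conflict, so bail out.
    if devices_per_node <= 1:
        return

    # From here on devices_per_node > 1 is known, so each failure has a
    # single condition instead of a compound one.
    if ":" in device:
        raise ValueError(
            "Cannot set devices_per_node>1 if a device numeral is specified, "
            f"the device was set to {device} and devices_per_node=={devices_per_node}"
        )
    if device == "CPU":
        raise ValueError(
            "Cannot set devices_per_node>1 if CPU is specified under devices"
        )
```

Each remaining branch then reads as "this device value is incompatible with multiple devices per node", which is the rule the PR is trying to enforce.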
tests/backends/test_dbscript.py
Outdated
@@ -578,3 +578,60 @@ def test_db_script_errors(fileutils, wlmutils, mlutils):
    # an in-memory script
    with pytest.raises(SSUnsupportedError):
        colo_ensemble.add_model(colo_model)

def test_inconsistent_params_add_script(fileutils, wlmutils, mlutils):
It looks like we're testing a lot of extra code here. Aren't the new validators in a constructor?
Can we build tests around the smallest unit?
...
# you could vary x, y, z and exercise the validation code
# more directly instead of allowing other code to potentially break
with pytest.raises(SSUnsupportedError):
    dbscript = DBScript(x, y, z)
Good point, especially because add_model, add_script, and add_function are essentially just factory methods for their respective DBEntity.
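To illustrate testing at the smallest unit, here is a table-driven sketch that exercises only a constructor. SketchScript is a hypothetical stand-in for DBScript that models nothing but the CPU check, and ValueError replaces SSUnsupportedError so the example has no dependencies. The test function uses bare asserts, so pytest would collect it as-is; the same table could also feed pytest.mark.parametrize.

```python
class SketchScript:
    """Hypothetical stand-in for DBScript; only the validation is modelled."""

    def __init__(self, name: str, device: str = "CPU", devices_per_node: int = 1):
        if devices_per_node > 1 and device == "CPU":
            raise ValueError(
                "Cannot set devices_per_node>1 if CPU is specified under devices"
            )
        self.name = name
        self.device = device
        self.devices_per_node = devices_per_node


# (device, devices_per_node, should_fail)
CASES = [
    ("CPU", 1, False),
    ("CPU", 2, True),
    ("GPU", 1, False),
    ("GPU", 2, False),
]


def constructor_fails(device: str, devices_per_node: int) -> bool:
    """Return True if building the object raises the validation error."""
    try:
        SketchScript("script", device, devices_per_node)
    except ValueError:
        return True
    return False


def test_inconsistent_params_constructor():
    # No experiment, launcher, or database is needed to reach the validator.
    for device, devices_per_node, should_fail in CASES:
        assert constructor_fails(device, devices_per_node) == should_fail
```

Since the factory methods only forward their arguments, covering the constructor this way leaves the integration tests free to check wiring rather than validation rules.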
if self.device in ["CPU"] and self.devices_per_node > 1:
    raise SSUnsupportedError(
        "Cannot set devices_per_node>1 if CPU is specified under devices"
    )
Do we need to check device?
if not self._check_device(self.device):
    raise ValueError("invalid device...")
I think we have a hole since device is an attribute. Changes after the constructor won't be validated.
Consider adding properties that use _check_device on sets!
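One way to close that hole is a property whose setter runs the validator, so assignments after construction are checked too. DeviceHolder is a hypothetical class for illustration, and the body of _check_device is an assumption (normalize case, require a CPU/GPU prefix), not necessarily what SmartSim does.

```python
class DeviceHolder:
    """Hypothetical class showing validation on every assignment."""

    def __init__(self, device: str) -> None:
        # Assigning through the property means the constructor and later
        # updates share one validation path.
        self.device = device

    @property
    def device(self) -> str:
        return self._device

    @device.setter
    def device(self, value: str) -> None:
        self._device = self._check_device(value)

    @staticmethod
    def _check_device(device: str) -> str:
        # Assumed behavior of the validator.
        normalized = device.upper()
        if not normalized.startswith(("CPU", "GPU")):
            raise ValueError(f"Device must start with CPU or GPU, got {device!r}")
        return normalized
```

With this shape, holder.device = "tpu" fails immediately and leaves the previous value in place, instead of surfacing later when the object is used.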
smartsim/entity/dbobject.py
Outdated
@@ -53,6 +54,7 @@ def __init__(
    self.file = self._check_filepath(file_path)
    self.device = self._check_device(device)
    self.devices_per_node = devices_per_node
    self._check_arguments()
Consider using a device: Literal["CPU", "GPU"] type hint to tighten the arguments for device (instead of str) in _check_device/__init__.
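A small sketch of the Literal idea, assuming Python 3.8+ for typing.Literal and typing.get_args. The Device alias and the check_device function are hypothetical names. A static type checker would flag other string literals at call sites, while the runtime check still guards untyped callers.

```python
from typing import Literal, get_args

# Hypothetical alias; static checkers reject any other literal passed as a Device.
Device = Literal["CPU", "GPU"]


def check_device(device: Device) -> Device:
    # Type hints are not enforced at runtime, so keep an explicit check and
    # derive the allowed values from the alias to avoid repeating them.
    if device not in get_args(Device):
        raise ValueError(f"Device must be one of {get_args(Device)}, got {device!r}")
    return device
```

One caveat: the diff in this PR also handles device strings containing a numeral (the ":" check), and a two-value Literal would exclude those, so the alias would need to account for that form if it remains supported.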
Thanks for fixing this up!
A few typos and some excellent suggestions from @ankona to address
smartsim/entity/dbobject.py
Outdated
    msg += f"the device was set to {self.device} and devices_per_node=={self.devices_per_node}"
    raise ValueError(msg)
if self.device in ["CPU", "GPU"] and self.devices_per_node > 1:
if self.device in ["GPU"] and self.devices_per_node > 1:
Just check against GPU (no in needed). I suppose too we could always make CPU and GPU something like an enum and attach them to DBObject so that we're comparing an object of some variety instead of a magic string.
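The enum idea might look roughly like this. Device, check_devices, and expand_devices are hypothetical names; here the enum is module-level rather than attached to DBObject, and ValueError stands in for the project's own exception type.

```python
from enum import Enum
from typing import List


class Device(str, Enum):
    """Hypothetical enum replacing the "CPU"/"GPU" magic strings."""

    CPU = "CPU"
    GPU = "GPU"


def check_devices(device: Device, devices_per_node: int) -> None:
    # Identity comparison against an enum member, no list membership needed.
    if device is Device.CPU and devices_per_node > 1:
        raise ValueError(
            "Cannot set devices_per_node>1 if CPU is specified under devices"
        )


def expand_devices(device: Device, devices_per_node: int) -> List[str]:
    # Mirrors the device-list expansion in the diff, using .value so the
    # output does not depend on how enum members are formatted.
    if devices_per_node > 1:
        return [f"{device.value}:{num}" for num in range(devices_per_node)]
    return [device.value]
```

Mixing in str keeps Device.CPU == "CPU" true, which may ease migration of callers that still pass plain strings through Device("CPU").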
tests/backends/test_dbmodel.py
Outdated
@@ -793,3 +793,37 @@ def test_colocated_db_model_errors(fileutils, wlmutils, mlutils):

    with pytest.raises(SSUnsupportedError):
        colo_ensemble.add_model(colo_model)

def test_inconsistent_params_add_ml_model(fileutils, wlmutils, mlutils):
    """Test error when devices_per_node parameter>1 when devices is set to CPU in addd_ml_model function"""
Typo: addd_ml_model should be add_ml_model.
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##           develop     #324      +/-   ##
===========================================
+ Coverage    87.30%   87.33%   +0.02%
===========================================
  Files           59       59
  Lines         3522     3529       +7
===========================================
+ Hits          3075     3082       +7
  Misses         447      447
Added an error for Model.add_ml_model and Model.add_script when the devices_per_node parameter is >1 and the device parameter is set to CPU.
Added tests for add_ml_model and add_script to ensure the error is raised correctly.