Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added type checking to params on model #676

Merged
merged 7 commits into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ To be released at some future point in time

Description

- Allow specifying Model and Ensemble parameters with
number-like types (e.g. numpy types)
- Pin watchdog to 4.x
- Update codecov to 4.5.0
- Remove build of Redis from setup.py
Expand All @@ -31,6 +33,11 @@ Description

Detailed Notes

- The serializer would fail if a parameter for a Model or Ensemble
was specified as a numpy dtype. The constructors for these
methods now validate that the input is number-like and convert
them to strings
([SmartSim-PR676](https://github.com/CrayLabs/SmartSim/pull/676))
- Pin watchdog to 4.x because v5 introduces new types and requires
updates to the type-checking
([SmartSim-PR690](https://github.com/CrayLabs/SmartSim/pull/690))
Expand Down
22 changes: 21 additions & 1 deletion smartsim/entity/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from __future__ import annotations

import itertools
import numbers
import re
import sys
import typing as t
Expand All @@ -46,6 +47,25 @@
logger = get_logger(__name__)


def _parse_model_parameters(params_dict: t.Dict[str, t.Any]) -> t.Dict[str, str]:
"""Convert the values in a params dict to strings
:raises TypeError: if params are of the wrong type
:return: param dictionary with values and keys cast as strings
"""
param_names: t.List[str] = []
parameters: t.List[str] = []
for name, val in params_dict.items():
param_names.append(name)
if isinstance(val, (str, numbers.Number)):
parameters.append(str(val))
else:
raise TypeError(

Check warning on line 62 in smartsim/entity/model.py

View check run for this annotation

Codecov / codecov/patch

smartsim/entity/model.py#L62

Added line #L62 was not covered by tests
"Incorrect type for model parameters\n"
+ "Must be numeric value or string."
)
return dict(zip(param_names, parameters))


class Model(SmartSimEntity):
def __init__(
self,
Expand All @@ -70,7 +90,7 @@
model as a batch job
"""
super().__init__(name, str(path), run_settings)
self.params = params
self.params = _parse_model_parameters(params)
self.params_as_args = params_as_args
self.incoming_entities: t.List[SmartSimEntity] = []
self._key_prefixing_enabled = False
Expand Down
15 changes: 15 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@

from uuid import uuid4

import numpy as np
import pytest

from smartsim import Experiment
from smartsim._core.control.manifest import LaunchedManifestBuilder
from smartsim._core.launcher.step import SbatchStep, SrunStep
from smartsim.entity import Ensemble, Model
from smartsim.entity.model import _parse_model_parameters
from smartsim.error import EntityExistsError, SSUnsupportedError
from smartsim.settings import RunSettings, SbatchSettings, SrunSettings
from smartsim.settings.mpiSettings import _BaseMPISettings
Expand Down Expand Up @@ -176,3 +178,16 @@ def test_models_batch_settings_are_ignored_in_ensemble(
step_cmd = step.step_cmds[0]
assert any("srun" in tok for tok in step_cmd) # call the model using run settings
assert not any("sbatch" in tok for tok in step_cmd) # no sbatch in sbatch


@pytest.mark.parametrize("dtype", [int, np.float32, str])
def test_good_model_params(dtype):
print(dtype(0.6))
params = {"foo": dtype(0.6)}
assert all(isinstance(val, str) for val in _parse_model_parameters(params).values())


@pytest.mark.parametrize("bad_val", [["eggs"], {"n": 5}, object])
def test_bad_model_params(bad_val):
with pytest.raises(TypeError):
_parse_model_parameters({"foo": bad_val})
2 changes: 1 addition & 1 deletion tests/test_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ def test_model_preview_properties(test_dir, wlmutils):
assert hw_rs == hello_world_model.run_settings.exe_args[0]
assert None == hello_world_model.batch_settings
assert "port" in list(hello_world_model.params.items())[0]
assert hw_port in list(hello_world_model.params.items())[0]
assert str(hw_port) in list(hello_world_model.params.items())[0]
assert "password" in list(hello_world_model.params.items())[1]
assert hw_password in list(hello_world_model.params.items())[1]

Expand Down