diff --git a/doc/changelog.md b/doc/changelog.md index 23bbed5c6..26388a05e 100644 --- a/doc/changelog.md +++ b/doc/changelog.md @@ -15,6 +15,8 @@ To be released at some future point in time Description +- Allow specifying Model and Ensemble parameters with + number-like types (e.g. numpy types) - Pin watchdog to 4.x - Update codecov to 4.5.0 - Remove build of Redis from setup.py @@ -31,6 +33,11 @@ Description Detailed Notes +- The serializer would fail if a parameter for a Model or Ensemble + was specified as a numpy dtype. The constructors for these + methods now validate that the input is number-like and convert + them to strings + ([SmartSim-PR676](https://github.com/CrayLabs/SmartSim/pull/676)) - Pin watchdog to 4.x because v5 introduces new types and requires updates to the type-checking ([SmartSim-PR690](https://github.com/CrayLabs/SmartSim/pull/690)) diff --git a/smartsim/entity/model.py b/smartsim/entity/model.py index 3f78e042c..a11a594fc 100644 --- a/smartsim/entity/model.py +++ b/smartsim/entity/model.py @@ -27,6 +27,7 @@ from __future__ import annotations import itertools +import numbers import re import sys import typing as t @@ -46,6 +47,25 @@ logger = get_logger(__name__) +def _parse_model_parameters(params_dict: t.Dict[str, t.Any]) -> t.Dict[str, str]: + """Convert the values in a params dict to strings + :raises TypeError: if params are of the wrong type + :return: param dictionary with values and keys cast as strings + """ + param_names: t.List[str] = [] + parameters: t.List[str] = [] + for name, val in params_dict.items(): + param_names.append(name) + if isinstance(val, (str, numbers.Number)): + parameters.append(str(val)) + else: + raise TypeError( + "Incorrect type for model parameters\n" + + "Must be numeric value or string." + ) + return dict(zip(param_names, parameters)) + + class Model(SmartSimEntity): def __init__( self, @@ -70,7 +90,7 @@ def __init__( model as a batch job """ super().__init__(name, str(path), run_settings) - self.params = params + self.params = _parse_model_parameters(params) self.params_as_args = params_as_args self.incoming_entities: t.List[SmartSimEntity] = [] self._key_prefixing_enabled = False diff --git a/tests/test_model.py b/tests/test_model.py index 64a68b299..152ce2058 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -26,12 +26,14 @@ from uuid import uuid4 +import numpy as np import pytest from smartsim import Experiment from smartsim._core.control.manifest import LaunchedManifestBuilder from smartsim._core.launcher.step import SbatchStep, SrunStep from smartsim.entity import Ensemble, Model +from smartsim.entity.model import _parse_model_parameters from smartsim.error import EntityExistsError, SSUnsupportedError from smartsim.settings import RunSettings, SbatchSettings, SrunSettings from smartsim.settings.mpiSettings import _BaseMPISettings @@ -176,3 +178,16 @@ def test_models_batch_settings_are_ignored_in_ensemble( step_cmd = step.step_cmds[0] assert any("srun" in tok for tok in step_cmd) # call the model using run settings assert not any("sbatch" in tok for tok in step_cmd) # no sbatch in sbatch + + +@pytest.mark.parametrize("dtype", [int, np.float32, str]) +def test_good_model_params(dtype): + print(dtype(0.6)) + params = {"foo": dtype(0.6)} + assert all(isinstance(val, str) for val in _parse_model_parameters(params).values()) + + +@pytest.mark.parametrize("bad_val", [["eggs"], {"n": 5}, object]) +def test_bad_model_params(bad_val): + with pytest.raises(TypeError): + _parse_model_parameters({"foo": bad_val}) diff --git a/tests/test_preview.py b/tests/test_preview.py index 3c7bed6fe..a18d10728 100644 --- a/tests/test_preview.py +++ b/tests/test_preview.py @@ -357,7 +357,7 @@ def test_model_preview_properties(test_dir, wlmutils): assert hw_rs == hello_world_model.run_settings.exe_args[0] assert None == hello_world_model.batch_settings assert "port" in list(hello_world_model.params.items())[0] - assert hw_port in list(hello_world_model.params.items())[0] + assert str(hw_port) in list(hello_world_model.params.items())[0] assert "password" in list(hello_world_model.params.items())[1] assert hw_password in list(hello_world_model.params.items())[1]