diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py index 1589018b85..eed365960b 100644 --- a/pybamm/models/base_model.py +++ b/pybamm/models/base_model.py @@ -285,8 +285,6 @@ def timescale(self): @timescale.setter def timescale(self, value): """Set the timescale""" - if not isinstance(value, (numbers.Number, pybamm.Scalar)): - raise ValueError("model.timescale must be a scalar") self._timescale = value @property diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py index 4cbbbf549b..2f770d231f 100644 --- a/pybamm/solvers/base_solver.py +++ b/pybamm/solvers/base_solver.py @@ -143,7 +143,10 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False): inputs = inputs or {} # Set model timescale - model.timescale_eval = model.timescale.evaluate(inputs=inputs) + if not isinstance(model.timescale, pybamm.Scalar): + raise ValueError("model.timescale must be a scalar") + + model.timescale_eval = model.timescale.evaluate() # Set model lengthscales model.length_scales_eval = { domain: scale.evaluate(inputs=inputs) diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py index 69a8ca48cf..c56d2802f3 100644 --- a/tests/unit/test_models/test_base_model.py +++ b/tests/unit/test_models/test_base_model.py @@ -935,13 +935,6 @@ def test_set_variables_error(self): with self.assertRaisesRegex(ValueError, "not var"): model.variables = {"not var": var} - def test_timescale(self): - model = pybamm.BaseModel() - model.timescale = 2.5 - self.assertEqual(model.timescale, 2.5) - with self.assertRaisesRegex(ValueError, "must be a scalar"): - model.timescale = pybamm.InputParameter("a") - if __name__ == "__main__": print("Add -v for more debug output") diff --git a/tests/unit/test_solvers/test_base_solver.py b/tests/unit/test_solvers/test_base_solver.py index c35325e0ba..9d5bcd9bce 100644 --- a/tests/unit/test_solvers/test_base_solver.py +++ b/tests/unit/test_solvers/test_base_solver.py @@ -282,6 +282,18 @@ def test_convert_to_casadi_format(self): self.assertEqual(model.convert_to_format, "casadi") pybamm.set_logging_level("WARNING") + def test_timescale_input_fail(self): + # Make sure timescale can't depend on inputs + model = pybamm.BaseModel() + v = pybamm.Variable("v") + model.rhs = {v: -1} + model.initial_conditions = {v: 1} + a = pybamm.InputParameter("a") + model.timescale = a + solver = pybamm.BaseSolver() + with self.assertRaisesRegex(ValueError, "model.timescale must be a scalar"): + solver.set_up(model) + def test_inputs_step(self): # Make sure interpolant inputs are dropped model = pybamm.BaseModel()