Skip to content

Commit

Permalink
#1889 fix test again
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Jan 20, 2022
1 parent ea8c9aa commit 50e364c
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 10 deletions.
2 changes: 0 additions & 2 deletions pybamm/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,8 +285,6 @@ def timescale(self):
@timescale.setter
def timescale(self, value):
"""Set the timescale"""
if not isinstance(value, (numbers.Number, pybamm.Scalar)):
raise ValueError("model.timescale must be a scalar")
self._timescale = value

@property
Expand Down
5 changes: 4 additions & 1 deletion pybamm/solvers/base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,10 @@ def set_up(self, model, inputs=None, t_eval=None, ics_only=False):
inputs = inputs or {}

# Set model timescale
model.timescale_eval = model.timescale.evaluate(inputs=inputs)
if not isinstance(model.timescale, pybamm.Scalar):
raise ValueError("model.timescale must be a scalar")

model.timescale_eval = model.timescale.evaluate()
# Set model lengthscales
model.length_scales_eval = {
domain: scale.evaluate(inputs=inputs)
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_models/test_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,13 +935,6 @@ def test_set_variables_error(self):
with self.assertRaisesRegex(ValueError, "not var"):
model.variables = {"not var": var}

def test_timescale(self):
model = pybamm.BaseModel()
model.timescale = 2.5
self.assertEqual(model.timescale, 2.5)
with self.assertRaisesRegex(ValueError, "must be a scalar"):
model.timescale = pybamm.InputParameter("a")


if __name__ == "__main__":
print("Add -v for more debug output")
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/test_solvers/test_base_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,18 @@ def test_convert_to_casadi_format(self):
self.assertEqual(model.convert_to_format, "casadi")
pybamm.set_logging_level("WARNING")

def test_timescale_input_fail(self):
# Make sure timescale can't depend on inputs
model = pybamm.BaseModel()
v = pybamm.Variable("v")
model.rhs = {v: -1}
model.initial_conditions = {v: 1}
a = pybamm.InputParameter("a")
model.timescale = a
solver = pybamm.BaseSolver()
with self.assertRaisesRegex(ValueError, "model.timescale must be a scalar"):
solver.set_up(model)

def test_inputs_step(self):
# Make sure interpolant inputs are dropped
model = pybamm.BaseModel()
Expand Down

0 comments on commit 50e364c

Please sign in to comment.