Skip to content

Commit

Permalink
#858 fix coverage and flake8
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjrobins committed Mar 15, 2020
1 parent 445ae54 commit 7ac7848
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 16 deletions.
12 changes: 4 additions & 8 deletions pybamm/expression_tree/state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,15 +242,13 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
return out

def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
if variable.id == pybamm.t.id:
return StateVectorDot(*self._y_slices, name=self.name + "'",
domain=self.domain,
auxiliary_domains=self.auxiliary_domains,
evaluation_array=self.evaluation_array)
else:
return pybamm.Scalar(0)
raise NotImplementedError

def _jac(self, variable):
if isinstance(variable, pybamm.StateVector):
Expand Down Expand Up @@ -309,14 +307,12 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
return out

def diff(self, variable):
if variable.id == self.id:
return pybamm.Scalar(1)
elif variable.id == pybamm.t.id:
if variable.id == pybamm.t.id:
raise pybamm.ModelError(
"cannot take second time derivative of a state vector"
)
else:
return pybamm.Scalar(0)
raise NotImplementedError

def _jac(self, variable):
if isinstance(variable, pybamm.StateVectorDot):
Expand Down
6 changes: 3 additions & 3 deletions tests/unit/test_discretisations/test_discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,7 +703,8 @@ def test_process_model_ode(self):
# test that any time derivatives of variables in rhs raises an
# error
model = pybamm.BaseModel()
model.rhs = {c: pybamm.div(N) + c.diff(pybamm.t), T: pybamm.div(q), S: pybamm.div(p)}
model.rhs = {c: pybamm.div(N) + c.diff(pybamm.t),
T: pybamm.div(q), S: pybamm.div(p)}
model.initial_conditions = {
c: pybamm.Scalar(2),
T: pybamm.Scalar(5),
Expand Down Expand Up @@ -846,8 +847,6 @@ def test_process_model_dae(self):
with self.assertRaises(pybamm.ModelError):
disc.process_model(model)



def test_process_model_concatenation(self):
# concatenation of variables as the key
cn = pybamm.Variable("c", domain=["negative electrode"])
Expand Down Expand Up @@ -1144,6 +1143,7 @@ def test_mass_matirx_inverse(self):
model.mass_matrix_inv.entries.toarray(), mass_inv.toarray()
)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys
Expand Down
7 changes: 4 additions & 3 deletions tests/unit/test_expression_tree/test_d_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ def test_time_derivative(self):
self.assertEqual(a.simplify().id, (2 * pybamm.t).id)
self.assertEqual(a.evaluate(t=1), 2)

a =(2 + pybamm.t**2).diff(pybamm.t)
self.assertEqual(a.simplify().id, (2*pybamm.t).id)
a = (2 + pybamm.t**2).diff(pybamm.t)
self.assertEqual(a.simplify().id, (2 * pybamm.t).id)
self.assertEqual(a.evaluate(t=1), 2)

def test_time_derivative_of_variable(self):
Expand All @@ -33,7 +33,7 @@ def test_time_derivative_of_variable(self):
self.assertEqual(a.name, "a'")

p = pybamm.Parameter('p')
a = (1 + p*pybamm.Variable('a')).diff(pybamm.t).simplify()
a = (1 + p * pybamm.Variable('a')).diff(pybamm.t).simplify()
self.assertIsInstance(a, pybamm.Multiplication)
self.assertEqual(a.children[0].name, 'p')
self.assertEqual(a.children[1].name, "a'")
Expand All @@ -59,6 +59,7 @@ def test_time_derivative_of_state_vector(self):
with self.assertRaises(pybamm.ModelError):
a = (sv).diff(pybamm.t).diff(pybamm.t)


if __name__ == "__main__":
print("Add -v for more debug output")
import sys
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,8 @@ def test_convert_external_variable(self):

# External only
self.assert_casadi_equal(
pybamm_u1.to_casadi(casadi_t, casadi_y, u=casadi_us), casadi_us["External 1"]
pybamm_u1.to_casadi(casadi_t, casadi_y, u=casadi_us),
casadi_us["External 1"]
)

# More complex
Expand Down
13 changes: 12 additions & 1 deletion tests/unit/test_expression_tree/test_state_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,14 @@ def test_evaluate_list(self):
y = np.linspace(0, 3, 31)
np.testing.assert_array_almost_equal(sv.evaluate(y=y), y[:, np.newaxis])

def test_diff(self):
a = pybamm.StateVector(slice(0, 10))
with self.assertRaises(NotImplementedError):
a.diff(a)
b = pybamm.StateVectorDot(slice(0, 10))
with self.assertRaises(NotImplementedError):
a.diff(b)

def test_name(self):
sv = pybamm.StateVector(slice(0, 10))
self.assertEqual(sv.name, "y[0:10]")
Expand All @@ -61,6 +69,7 @@ def test_failure(self):
with self.assertRaisesRegex(TypeError, "all y_slices must be slice objects"):
pybamm.StateVector(slice(0, 10), 1)


class TestStateVectorDot(unittest.TestCase):
def test_evaluate(self):
sv = pybamm.StateVectorDot(slice(0, 10))
Expand All @@ -72,14 +81,16 @@ def test_evaluate(self):
# Try evaluating with a y that is too short
y_dot2 = np.ones(5)
with self.assertRaisesRegex(
ValueError, "y_dot is too short, so value with slice is smaller than expected"
ValueError,
"y_dot is too short, so value with slice is smaller than expected"
):
sv.evaluate(y_dot=y_dot2)

def test_name(self):
sv = pybamm.StateVectorDot(slice(0, 10))
self.assertEqual(sv.name, "y_dot[0:10]")


if __name__ == "__main__":
print("Add -v for more debug output")
import sys
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_parameters/test_parameters_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def test_list_params(self):
# but must not intefere with existing input dir if it exists
# in the current dir...


if __name__ == "__main__":
print("Add -v for more debug output")
import sys
Expand Down

0 comments on commit 7ac7848

Please sign in to comment.