Skip to content

Commit

Permalink
#247 fix processed variable
Browse files Browse the repository at this point in the history
  • Loading branch information
valentinsulzer committed Feb 19, 2020
1 parent fe7995a commit ec948de
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 16 deletions.
14 changes: 9 additions & 5 deletions pybamm/processed_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,22 +149,24 @@ def initialise_2D(self):
self.entries = entries
self.dimensions = 2
if self.domain[0] in ["negative particle", "positive particle"]:
self.spatial_var_name = "r"
self.first_dimension = "r"
self.r_sol = space
elif self.domain[0] in [
"negative electrode",
"separator",
"positive electrode",
]:
self.spatial_var_name = "x"
self.first_dimension = "x"
self.x_sol = space
elif self.domain == ["current collector"]:
self.spatial_var_name = "z"
self.first_dimension = "z"
self.z_sol = space
else:
self.spatial_var_name = "x"
self.first_dimension = "x"
self.x_sol = space

self.first_dim_pts = space

# set up interpolation
# note that the order of 't' and 'space' is the reverse of what you'd expect

Expand Down Expand Up @@ -242,6 +244,8 @@ def initialise_3D(self):
# assign attributes for reference
self.entries = entries
self.dimensions = 3
self.first_dim_pts = first_dim_pts
self.second_dim_pts = second_dim_pts

# set up interpolation
self._interpolation_function = interp.RegularGridInterpolator(
Expand Down Expand Up @@ -335,7 +339,7 @@ def __call__(self, t=None, x=None, r=None, y=None, z=None, warn=True):

def call_2D(self, t, x, r, z):
"Evaluate a 2D variable"
spatial_var = eval_dimension_name(self.spatial_var_name, x, r, None, z)
spatial_var = eval_dimension_name(self.first_dimension, x, r, None, z)
return self._interpolation_function(t, spatial_var)

def call_3D(self, t, x, r, y, z):
Expand Down
18 changes: 9 additions & 9 deletions tests/integration/test_models/standard_output_comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ def test_all(self, skip_first_timestep=False):
self.run_test_class(VariablesComparison, skip_first_timestep)

if self.chemistry == "Lithium-ion":
self.run_test_class(ParticleConcentrationComparison)
self.run_test_class(ParticleConcentrationComparison, skip_first_timestep)
elif self.chemistry == "Lead-acid":
self.run_test_class(PorosityComparison)
self.run_test_class(PorosityComparison, skip_first_timestep)


class BaseOutputComparison(object):
Expand All @@ -67,13 +67,14 @@ def compare(self, var, tol=1e-2):
model_variables = [solution[var] for solution in self.solutions]
var0 = model_variables[0]

if var0.mesh is None:
x = None
else:
x = var0.mesh[0].nodes
spatial_pts = {}
if var0.dimensions >= 2:
spatial_pts[var0.first_dimension] = var0.first_dim_pts
if var0.dimensions >= 3:
spatial_pts[var0.second_dimension] = var0.second_dim_pts

# Calculate tolerance based on the value of var0
maxvar0 = np.max(abs(var0(self.t, x)))
maxvar0 = np.max(abs(var0(self.t, **spatial_pts)))
if maxvar0 < 1e-14:
decimal = -int(np.log10(tol))
else:
Expand All @@ -82,7 +83,7 @@ def compare(self, var, tol=1e-2):
for model_var in model_variables[1:]:
np.testing.assert_equal(var0.dimensions, model_var.dimensions)
np.testing.assert_array_almost_equal(
model_var(self.t, x), var0(self.t, x), decimal
model_var(self.t, **spatial_pts), var0(self.t, **spatial_pts), decimal
)


Expand Down Expand Up @@ -123,7 +124,6 @@ def test_all(self):
self.compare("X-averaged negative electrode open circuit potential")
self.compare("X-averaged positive electrode open circuit potential")
self.compare("Terminal voltage")
self.compare("X-averaged electrolyte overpotential")
self.compare("X-averaged solid phase ohmic losses")
self.compare("Negative electrode reaction overpotential")
self.compare("Positive electrode reaction overpotential")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_compare_outputs_surface_form(self):

# solve model
solutions = []
t_eval = np.linspace(0, 0.2, 100)
t_eval = np.linspace(0, 3600, 100)
for i, model in enumerate(models):
solution = pybamm.CasadiSolver().solve(model, t_eval)
solutions.append(solution)
Expand All @@ -58,5 +58,4 @@ def test_compare_outputs_surface_form(self):

if "-v" in sys.argv:
debug = True
pybamm.set_logging_level("DEBUG")
unittest.main()

0 comments on commit ec948de

Please sign in to comment.