Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added pchip Interpolator #4871

Merged
merged 8 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added 'get_summary_variables' to return dictionary of computed summary variables ([#4824](https://github.com/pybamm-team/PyBaMM/pull/4824))
- Added support for particle size distributions combined with particle mechanics. ([#4807](https://github.com/pybamm-team/PyBaMM/pull/4807))
- Added InputParameter support in PyBamm experiments ([#4826](https://github.com/pybamm-team/PyBaMM/pull/4826))
- Added pchip Interpolator ([#4871](https://github.com/pybamm-team/PyBaMM/pull/4871))

## Breaking changes

Expand Down
42 changes: 36 additions & 6 deletions src/pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,42 @@ def _convert(self, symbol, t, y, y_dot, inputs):
elif symbol.interpolator == "cubic":
solver = "bspline"
elif symbol.interpolator == "pchip":
raise NotImplementedError(
"The interpolator 'pchip' is not supported by CasAdi. "
"Use 'linear' or 'cubic' instead. "
"Alternatively, set 'model.convert_to_format = 'python'' "
"and use a non-CasADi solver. "
)
x_np = np.array(symbol.x[0])
y_np = np.array(symbol.y)
pchip_interp = interpolate.PchipInterpolator(x_np, y_np, axis=0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the axis=0 argument is unnecessary because of the upstream check for 1D interpolation, but I could be wrong

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah its not required

d_np = pchip_interp.derivative()(x_np)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we access this array directly from the properties of the pchip class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes i think so as protected attribute

x_sym = converted_children[0]
result = casadi.MX.zeros(x_sym.shape)

# Loop over each interval [x_np[i], x_np[i+1]]
for i in range(len(x_np) - 1):
x0 = x_np[i]
x1 = x_np[i + 1]
h_val = x1 - x0
h_val_mx = casadi.MX(h_val)
y0 = casadi.MX(y_np[i])
y1 = casadi.MX(y_np[i + 1])
d0 = casadi.MX(d_np[i])
d1 = casadi.MX(d_np[i + 1])

t = (x_sym - x0) / h_val_mx
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is minor, but can we keep a consistent notation with either t or x but not both?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

changed this to use x


# Define the Hermite basis functions.
h00 = 2 * t**3 - 3 * t**2 + 1
h10 = t**3 - 2 * t**2 + t
h01 = -2 * t**3 + 3 * t**2
h11 = t**3 - t**2

piece_val = (
h00 * y0
+ h10 * h_val_mx * d0
+ h01 * y1
+ h11 * h_val_mx * d1
)

cond = casadi.logic_and(x_sym >= x0, x_sym <= x1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we try to extrapolate outside the x bounds, this will return the value defined by

result = casadi.MX.zeros(x_sym.shape)

which can be problematic during simulation. Can you please add extrapolation that aligns with scipy's?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added extrapolation aligning with scipy

result = casadi.if_else(cond, piece_val, result)
return result
else: # pragma: no cover
raise NotImplementedError(
f"Unknown interpolator: {symbol.interpolator}"
Expand Down
17 changes: 17 additions & 0 deletions tests/unit/test_experiments/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pybamm
import pytest
import numpy as np
import casadi


class TestExperiment:
Expand Down Expand Up @@ -281,3 +282,19 @@ def test_voltage_without_directions(self):

voltage = solution["Terminal voltage [V]"].entries
assert np.allclose(voltage, 2.5, atol=1e-3)

def test_pchip_interpolation_experiment(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a few more tests:

  • Comparing the results with scipy.interpolate.PchipInterpolator
  • Check non-uniform x grids
  • Check that it throws an error for non-increasing x inputs
  • Check left/right extrapolation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added these tests

x = np.linspace(0, 1, 11)
y_values = x**3

y = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, y, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
casadi_y = casadi.MX.sym("y", len(test_points), 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])

casadi_results = f(test_points.reshape((-1, 1)))
expected = interp.evaluate(y=test_points)
np.testing.assert_allclose(casadi_results, expected, rtol=1e-7, atol=1e-6)
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def test_interpolation(self):
casadi_y = casadi.MX.sym("y", 2)
# linear
y_test = np.array([0.4, 0.6])
for interpolator in ["linear", "cubic"]:
for interpolator in ["linear", "cubic", "pchip"]:
interp = pybamm.Interpolant(x, 2 * x, y, interpolator=interpolator)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
np.testing.assert_allclose(
interp.evaluate(y=y_test), f(y_test), rtol=1e-7, atol=1e-6
)
expected = interp.evaluate(y=y_test)
np.testing.assert_allclose(expected, f(y_test), rtol=1e-7, atol=1e-6)

# square
y = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, x**2, y, interpolator="cubic")
Expand All @@ -188,11 +188,6 @@ def test_interpolation(self):
interp.evaluate(y=y_test), f(y_test), rtol=1e-7, atol=1e-6
)

# error for pchip interpolator
interp = pybamm.Interpolant(x, data, y, interpolator="pchip")
with pytest.raises(NotImplementedError, match="The interpolator"):
interp_casadi = interp.to_casadi(y=casadi_y)

# error for not recognized interpolator
with pytest.raises(ValueError, match="interpolator"):
interp = pybamm.Interpolant(x, data, y, interpolator="idonotexist")
Expand Down