Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added pchip Interpolator #4871

Merged
merged 8 commits into from
Feb 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added 'get_summary_variables' to return dictionary of computed summary variables ([#4824](https://github.com/pybamm-team/PyBaMM/pull/4824))
- Added support for particle size distributions combined with particle mechanics. ([#4807](https://github.com/pybamm-team/PyBaMM/pull/4807))
- Added InputParameter support in PyBamm experiments ([#4826](https://github.com/pybamm-team/PyBaMM/pull/4826))
- Added support for the `"pchip"` interpolator using the CasADI backend. ([#4871](https://github.com/pybamm-team/PyBaMM/pull/4871))

## Breaking changes

Expand Down
46 changes: 41 additions & 5 deletions src/pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,48 @@ def _convert(self, symbol, t, y, y_dot, inputs):
elif symbol.interpolator == "cubic":
solver = "bspline"
elif symbol.interpolator == "pchip":
raise NotImplementedError(
"The interpolator 'pchip' is not supported by CasAdi. "
"Use 'linear' or 'cubic' instead. "
"Alternatively, set 'model.convert_to_format = 'python'' "
"and use a non-CasADi solver. "
x_np = np.array(symbol.x[0])
y_np = np.array(symbol.y)
pchip_interp = interpolate.PchipInterpolator(x_np, y_np)
d_np = pchip_interp.derivative()(x_np)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we access this array directly from the properties of the pchip class?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes i think so as protected attribute

x = converted_children[0]

def hermite_poly(i):
x0 = x_np[i]
x1 = x_np[i + 1]
h_val = x1 - x0
h_val_mx = casadi.MX(h_val)
y0 = casadi.MX(y_np[i])
y1 = casadi.MX(y_np[i + 1])
d0 = casadi.MX(d_np[i])
d1 = casadi.MX(d_np[i + 1])
xn = (x - x0) / h_val_mx
h00 = 2 * xn**3 - 3 * xn**2 + 1
h10 = xn**3 - 2 * xn**2 + xn
h01 = -2 * xn**3 + 3 * xn**2
h11 = xn**3 - xn**2
return (
h00 * y0
+ h10 * h_val_mx * d0
+ h01 * y1
+ h11 * h_val_mx * d1
)

# Build piecewise polynomial for points inside the domain.
inside = casadi.MX.zeros(x.shape)
for i in range(len(x_np) - 1):
cond = casadi.logic_and(x >= x_np[i], x <= x_np[i + 1])
inside = casadi.if_else(cond, hermite_poly(i), inside)

# Extrapolation:
left = hermite_poly(0) # For x < x_np[0]
right = hermite_poly(len(x_np) - 2) # For x > x_np[-1]

# if greater than the maximum, use right; otherwise, use the piecewise value.
result = casadi.if_else(
x < x_np[0], left, casadi.if_else(x > x_np[-1], right, inside)
)
return result
else: # pragma: no cover
raise NotImplementedError(
f"Unknown interpolator: {symbol.interpolator}"
Expand Down
74 changes: 74 additions & 0 deletions tests/unit/test_experiments/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pybamm
import pytest
import numpy as np
import casadi
from scipy.interpolate import PchipInterpolator


class TestExperiment:
Expand Down Expand Up @@ -281,3 +283,75 @@ def test_voltage_without_directions(self):

voltage = solution["Terminal voltage [V]"].entries
assert np.allclose(voltage, 2.5, atol=1e-3)

def test_pchip_interpolation_experiment(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you please add a few more tests:

  • Comparing the results with scipy.interpolate.PchipInterpolator
  • Check non-uniform x grids
  • Check that it throws an error for non-increasing x inputs
  • Check left/right extrapolation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added these tests

x = np.linspace(0, 1, 11)
y_values = x**3

y = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, y, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
casadi_y = casadi.MX.sym("y", len(test_points), 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])

casadi_results = f(test_points.reshape((-1, 1)))
expected = interp.evaluate(y=test_points)
np.testing.assert_allclose(casadi_results, expected, rtol=1e-7, atol=1e-6)

def test_pchip_interpolation_uniform_grid(self):
x = np.linspace(0, 1, 11)
y_values = np.sin(x)

state = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
expected = PchipInterpolator(x, y_values)(test_points)

casadi_y = casadi.MX.sym("y", 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
result = np.array(f(test_points)).flatten()

np.testing.assert_allclose(result, expected, rtol=1e-7, atol=1e-6)

def test_pchip_interpolation_nonuniform_grid(self):
x = np.array([0, 0.05, 0.2, 0.4, 0.65, 1.0])
y_values = np.exp(-x)
state = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
expected = PchipInterpolator(x, y_values)(test_points)

casadi_y = casadi.MX.sym("y", 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
result = np.array(f(test_points)).flatten()

np.testing.assert_allclose(result, expected, rtol=1e-7, atol=1e-6)

def test_pchip_non_increasing_x(self):
x = np.array([0, 0.5, 0.5, 1.0])
y_values = np.linspace(0, 1, 4)
state = pybamm.StateVector(slice(0, 1))
with pytest.raises(ValueError, match="strictly increasing sequence"):
_ = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

def test_pchip_extrapolation(self):
x = np.linspace(0, 1, 11)
y_values = np.log1p(x) # a smooth function on [0,1]
state = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

test_points = np.array([-0.1, 1.1])
expected = PchipInterpolator(x, y_values)(test_points)

casadi_y = casadi.MX.sym("y", 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
result = np.array(f(test_points)).flatten()

np.testing.assert_allclose(result, expected, rtol=1e-7, atol=1e-6)
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def test_interpolation(self):
casadi_y = casadi.MX.sym("y", 2)
# linear
y_test = np.array([0.4, 0.6])
for interpolator in ["linear", "cubic"]:
for interpolator in ["linear", "cubic", "pchip"]:
interp = pybamm.Interpolant(x, 2 * x, y, interpolator=interpolator)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
np.testing.assert_allclose(
interp.evaluate(y=y_test), f(y_test), rtol=1e-7, atol=1e-6
)
expected = interp.evaluate(y=y_test)
np.testing.assert_allclose(expected, f(y_test), rtol=1e-7, atol=1e-6)

# square
y = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, x**2, y, interpolator="cubic")
Expand All @@ -188,11 +188,6 @@ def test_interpolation(self):
interp.evaluate(y=y_test), f(y_test), rtol=1e-7, atol=1e-6
)

# error for pchip interpolator
interp = pybamm.Interpolant(x, data, y, interpolator="pchip")
with pytest.raises(NotImplementedError, match="The interpolator"):
interp_casadi = interp.to_casadi(y=casadi_y)

# error for not recognized interpolator
with pytest.raises(ValueError, match="interpolator"):
interp = pybamm.Interpolant(x, data, y, interpolator="idonotexist")
Expand Down