Skip to content

Commit

Permalink
Added pchip Interpolator (#4871)
Browse files Browse the repository at this point in the history
* added pchip interpolator

* added changelog and few minor changes

* fixing tests

* added extrapolation and tests

* Update CHANGELOG.md

Co-authored-by: Marc Berliner <[email protected]>

---------

Co-authored-by: Marc Berliner <[email protected]>
  • Loading branch information
Rishab87 and MarcBerliner authored Feb 24, 2025
1 parent cd272f7 commit 603fe6f
Show file tree
Hide file tree
Showing 4 changed files with 120 additions and 14 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added 'get_summary_variables' to return dictionary of computed summary variables ([#4824](https://github.com/pybamm-team/PyBaMM/pull/4824))
- Added support for particle size distributions combined with particle mechanics. ([#4807](https://github.com/pybamm-team/PyBaMM/pull/4807))
- Added InputParameter support in PyBamm experiments ([#4826](https://github.com/pybamm-team/PyBaMM/pull/4826))
- Added support for the `"pchip"` interpolator using the CasADI backend. ([#4871](https://github.com/pybamm-team/PyBaMM/pull/4871))

## Breaking changes

Expand Down
46 changes: 41 additions & 5 deletions src/pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,48 @@ def _convert(self, symbol, t, y, y_dot, inputs):
elif symbol.interpolator == "cubic":
solver = "bspline"
elif symbol.interpolator == "pchip":
raise NotImplementedError(
"The interpolator 'pchip' is not supported by CasAdi. "
"Use 'linear' or 'cubic' instead. "
"Alternatively, set 'model.convert_to_format = 'python'' "
"and use a non-CasADi solver. "
x_np = np.array(symbol.x[0])
y_np = np.array(symbol.y)
pchip_interp = interpolate.PchipInterpolator(x_np, y_np)
d_np = pchip_interp.derivative()(x_np)
x = converted_children[0]

def hermite_poly(i):
x0 = x_np[i]
x1 = x_np[i + 1]
h_val = x1 - x0
h_val_mx = casadi.MX(h_val)
y0 = casadi.MX(y_np[i])
y1 = casadi.MX(y_np[i + 1])
d0 = casadi.MX(d_np[i])
d1 = casadi.MX(d_np[i + 1])
xn = (x - x0) / h_val_mx
h00 = 2 * xn**3 - 3 * xn**2 + 1
h10 = xn**3 - 2 * xn**2 + xn
h01 = -2 * xn**3 + 3 * xn**2
h11 = xn**3 - xn**2
return (
h00 * y0
+ h10 * h_val_mx * d0
+ h01 * y1
+ h11 * h_val_mx * d1
)

# Build piecewise polynomial for points inside the domain.
inside = casadi.MX.zeros(x.shape)
for i in range(len(x_np) - 1):
cond = casadi.logic_and(x >= x_np[i], x <= x_np[i + 1])
inside = casadi.if_else(cond, hermite_poly(i), inside)

# Extrapolation:
left = hermite_poly(0) # For x < x_np[0]
right = hermite_poly(len(x_np) - 2) # For x > x_np[-1]

# if greater than the maximum, use right; otherwise, use the piecewise value.
result = casadi.if_else(
x < x_np[0], left, casadi.if_else(x > x_np[-1], right, inside)
)
return result
else: # pragma: no cover
raise NotImplementedError(
f"Unknown interpolator: {symbol.interpolator}"
Expand Down
74 changes: 74 additions & 0 deletions tests/unit/test_experiments/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import pybamm
import pytest
import numpy as np
import casadi
from scipy.interpolate import PchipInterpolator


class TestExperiment:
Expand Down Expand Up @@ -281,3 +283,75 @@ def test_voltage_without_directions(self):

voltage = solution["Terminal voltage [V]"].entries
assert np.allclose(voltage, 2.5, atol=1e-3)

def test_pchip_interpolation_experiment(self):
x = np.linspace(0, 1, 11)
y_values = x**3

y = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, y, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
casadi_y = casadi.MX.sym("y", len(test_points), 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])

casadi_results = f(test_points.reshape((-1, 1)))
expected = interp.evaluate(y=test_points)
np.testing.assert_allclose(casadi_results, expected, rtol=1e-7, atol=1e-6)

def test_pchip_interpolation_uniform_grid(self):
x = np.linspace(0, 1, 11)
y_values = np.sin(x)

state = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
expected = PchipInterpolator(x, y_values)(test_points)

casadi_y = casadi.MX.sym("y", 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
result = np.array(f(test_points)).flatten()

np.testing.assert_allclose(result, expected, rtol=1e-7, atol=1e-6)

def test_pchip_interpolation_nonuniform_grid(self):
x = np.array([0, 0.05, 0.2, 0.4, 0.65, 1.0])
y_values = np.exp(-x)
state = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

test_points = np.linspace(0, 1, 21)
expected = PchipInterpolator(x, y_values)(test_points)

casadi_y = casadi.MX.sym("y", 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
result = np.array(f(test_points)).flatten()

np.testing.assert_allclose(result, expected, rtol=1e-7, atol=1e-6)

def test_pchip_non_increasing_x(self):
x = np.array([0, 0.5, 0.5, 1.0])
y_values = np.linspace(0, 1, 4)
state = pybamm.StateVector(slice(0, 1))
with pytest.raises(ValueError, match="strictly increasing sequence"):
_ = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

def test_pchip_extrapolation(self):
x = np.linspace(0, 1, 11)
y_values = np.log1p(x) # a smooth function on [0,1]
state = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, y_values, state, interpolator="pchip")

test_points = np.array([-0.1, 1.1])
expected = PchipInterpolator(x, y_values)(test_points)

casadi_y = casadi.MX.sym("y", 1)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
result = np.array(f(test_points)).flatten()

np.testing.assert_allclose(result, expected, rtol=1e-7, atol=1e-6)
Original file line number Diff line number Diff line change
Expand Up @@ -159,13 +159,13 @@ def test_interpolation(self):
casadi_y = casadi.MX.sym("y", 2)
# linear
y_test = np.array([0.4, 0.6])
for interpolator in ["linear", "cubic"]:
for interpolator in ["linear", "cubic", "pchip"]:
interp = pybamm.Interpolant(x, 2 * x, y, interpolator=interpolator)
interp_casadi = interp.to_casadi(y=casadi_y)
f = casadi.Function("f", [casadi_y], [interp_casadi])
np.testing.assert_allclose(
interp.evaluate(y=y_test), f(y_test), rtol=1e-7, atol=1e-6
)
expected = interp.evaluate(y=y_test)
np.testing.assert_allclose(expected, f(y_test), rtol=1e-7, atol=1e-6)

# square
y = pybamm.StateVector(slice(0, 1))
interp = pybamm.Interpolant(x, x**2, y, interpolator="cubic")
Expand All @@ -188,11 +188,6 @@ def test_interpolation(self):
interp.evaluate(y=y_test), f(y_test), rtol=1e-7, atol=1e-6
)

# error for pchip interpolator
interp = pybamm.Interpolant(x, data, y, interpolator="pchip")
with pytest.raises(NotImplementedError, match="The interpolator"):
interp_casadi = interp.to_casadi(y=casadi_y)

# error for not recognized interpolator
with pytest.raises(ValueError, match="interpolator"):
interp = pybamm.Interpolant(x, data, y, interpolator="idonotexist")
Expand Down

0 comments on commit 603fe6f

Please sign in to comment.