Skip to content

Commit

Permalink
Implement second derivative for Hermite splines (#30)
Browse files Browse the repository at this point in the history
  • Loading branch information
faymanns authored Oct 14, 2024
1 parent 9583f0b commit d6ccb70
Show file tree
Hide file tree
Showing 11 changed files with 80 additions and 24 deletions.
2 changes: 1 addition & 1 deletion docs/pyplots/plot_b1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

b1 = splinebox.basis_functions.B1()

t = np.linspace(-2, 2, 100)
t = np.linspace(-2, 2, 1000)

b1_0th = b1.eval(t)
b1_1st = b1.eval(t, derivative=1)
Expand Down
2 changes: 1 addition & 1 deletion docs/pyplots/plot_b2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

b2 = splinebox.basis_functions.B2()

t = np.linspace(-2, 2, 100)
t = np.linspace(-2, 2, 1000)

b2_0th = b2.eval(t)
b2_1st = b2.eval(t, derivative=1)
Expand Down
2 changes: 1 addition & 1 deletion docs/pyplots/plot_b3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

b3 = splinebox.basis_functions.B3()

t = np.linspace(-3, 3, 100)
t = np.linspace(-3, 3, 1000)

b3_0th = b3.eval(t)
b3_1st = b3.eval(t, derivative=1)
Expand Down
2 changes: 1 addition & 1 deletion docs/pyplots/plot_catmullrom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

basis = splinebox.basis_functions.CatmullRom()

t = np.linspace(-3, 3, 100)
t = np.linspace(-3, 3, 1000)

basis_0th = basis.eval(t)
basis_1st = basis.eval(t, derivative=1)
Expand Down
13 changes: 9 additions & 4 deletions docs/pyplots/plot_cubichermite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,31 @@
import numpy as np
import splinebox.basis_functions

fig, axes = plt.subplots(2, 2, sharex=True)
fig, axes = plt.subplots(3, 2, sharex=True)

basis = splinebox.basis_functions.CubicHermite()

t = np.linspace(-3, 3, 100)
t = np.linspace(-3, 3, 1000)

basis_0th = basis.eval(t)
basis_1st = basis.eval(t, derivative=1)
basis_2nd = basis.eval(t, derivative=2)

fig.suptitle("Cubic Hermite basis function and its derivatives")
axes[0][0].plot(t, basis_0th[0], label=r"$\Phi_1(t)$")
axes[0][0].legend()
axes[1][0].plot(t, basis_1st[0], label=r"$\frac{d\Phi_1}{dt}(t)$")
axes[1][0].legend()
axes[1][0].set_xlabel(r"$t$")
axes[2][0].plot(t, basis_2nd[0], label=r"$\frac{d^2\Phi_1}{dt^2}(t)$")
axes[2][0].legend()
axes[2][0].set_xlabel(r"$t$")
axes[0][1].plot(t, basis_0th[1], label=r"$\Phi_2(t)$")
axes[0][1].legend()
axes[1][1].plot(t, basis_1st[1], label=r"$\frac{d\Phi_2}{dt}(t)$")
axes[1][1].legend()
axes[1][1].set_xlabel(r"$t$")
axes[2][1].plot(t, basis_2nd[1], label=r"$\frac{d^2\Phi_2}{dt^2}(t)$")
axes[2][1].legend()
axes[2][1].set_xlabel(r"$t$")
plt.tight_layout()

plt.show()
2 changes: 1 addition & 1 deletion docs/pyplots/plot_exponential.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
M = 5
basis = splinebox.basis_functions.Exponential(M=M)

t = np.linspace(-3, 3, 100)
t = np.linspace(-3, 3, 1000)

basis_0th = basis.eval(t)
basis_1st = basis.eval(t, derivative=1)
Expand Down
13 changes: 9 additions & 4 deletions docs/pyplots/plot_exponentialhermite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,32 @@
import numpy as np
import splinebox.basis_functions

fig, axes = plt.subplots(2, 2, sharex=True)
fig, axes = plt.subplots(3, 2, sharex=True)

M = 5
basis = splinebox.basis_functions.ExponentialHermite(M=M)

t = np.linspace(-3, 3, 100)
t = np.linspace(-3, 3, 1000)

basis_0th = basis.eval(t)
basis_1st = basis.eval(t, derivative=1)
basis_2nd = basis.eval(t, derivative=2)

fig.suptitle(f"Exponential Hermite basis function and its derivatives for $M={M}$")
axes[0][0].plot(t, basis_0th[0], label=r"$\Phi_1(t)$")
axes[0][0].legend()
axes[1][0].plot(t, basis_1st[0], label=r"$\frac{d\Phi_1}{dt}(t)$")
axes[1][0].legend()
axes[1][0].set_xlabel(r"$t$")
axes[2][0].plot(t, basis_2nd[0], label=r"$\frac{d^2\Phi_1}{dt^2}(t)$")
axes[2][0].legend()
axes[2][0].set_xlabel(r"$t$")
axes[0][1].plot(t, basis_0th[1], label=r"$\Phi_2(t)$")
axes[0][1].legend()
axes[1][1].plot(t, basis_1st[1], label=r"$\frac{d\Phi_2}{dt}(t)$")
axes[1][1].legend()
axes[1][1].set_xlabel(r"$t$")
axes[2][1].plot(t, basis_2nd[1], label=r"$\frac{d^2\Phi_2}{dt^2}(t)$")
axes[2][1].legend()
axes[2][1].set_xlabel(r"$t$")
plt.tight_layout()

plt.show()
2 changes: 1 addition & 1 deletion docs/pyplots/plot_no_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

spline.control_points = control_points

t = np.linspace(0, spline.M - 1, 100)
t = np.linspace(0, spline.M - 1, 1000)
vals = spline.eval(t)

plt.figure(figsize=(6, 3))
Expand Down
2 changes: 1 addition & 1 deletion docs/pyplots/plot_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

spline.control_points = control_points

t = np.linspace(0, spline.M - 1, 100)
t = np.linspace(0, spline.M - 1, 1000)
vals = spline.eval(t)

plt.figure(figsize=(6, 3))
Expand Down
58 changes: 54 additions & 4 deletions src/splinebox/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,9 +739,28 @@ def h32prime(t): # pragma: no cover
val = 3 * t * t + 4 * t + 1
return val

def _derivative_2(self, t):
return np.array([self.h31primeprime(t), self.h32primeprime(t)])

@staticmethod
def _derivative_2(t):
raise RuntimeError("CubicHermite isn't twice differentiable.")
@numba.vectorize([numba.float64(numba.float64)], nopython=True, cache=True)
def h31primeprime(t): # pragma: no cover
val = 0
if t >= 0 and t <= 1:
val = 12 * t - 6
elif t < 0 and t >= -1:
val = -12 * t - 6
return val

@staticmethod
@numba.vectorize([numba.float64(numba.float64)], nopython=True, cache=True)
def h32primeprime(t): # pragma: no cover
val = 0
if t >= 0 and t <= 1:
val = 6 * t - 4
elif t < 0 and t >= -1:
val = 6 * t + 4
return val

def h31_autocorrelation(self, i, j, M): # pragma: no cover
"""
Expand Down Expand Up @@ -972,9 +991,40 @@ def _g2prime(t, M):
val = _g2prime(t, M) if t >= 0 else _g2prime(-t, M)
return val

def _derivative_2(self, t):
return np.array([self._he31primeprime(t, self.M), self._he32primeprime(t, self.M)])

@staticmethod
@numba.vectorize([numba.float64(numba.float64, numba.float64)], nopython=True, cache=True)
def _he31primeprime(t, M): # pragma: no cover
def _g1primeprime(t, M):
val = 0
if t >= 0 and t <= 1:
alpha = np.pi / M
denom = (alpha * np.cos(alpha)) - np.sin(alpha)
num = 2 * alpha**2 * np.sin(alpha - (2 * alpha * t))
val = num / denom
return val

val = _g1primeprime(t, M) if t >= 0 else _g1primeprime(-t, M)
return val

@staticmethod
def _derivative_2(x):
raise RuntimeError("ExponentialHermite isn't twice differentiable.")
@numba.vectorize([numba.float64(numba.float64, numba.float64)], nopython=True, cache=True)
def _he32primeprime(t, M): # pragma: no cover
def _g2primeprime(t, M):
val = 0
if t >= 0 and t <= 1:
alpha = np.pi / M
denom = ((alpha * np.cos(alpha)) - np.sin(alpha)) * 8 * alpha * np.sin(alpha)
num = +(8 * alpha**2 * np.sin(alpha) * np.cos(2 * alpha * (t - 0.5))) - (
8 * alpha**3 * np.cos(2 * alpha * (t - 1))
)
val = num / denom
return val

val = _g2primeprime(t, M) if t >= 0 else -1 * _g2primeprime(-t, M)
return val

@staticmethod
def filter_symmetric(s):
Expand Down
6 changes: 1 addition & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,7 @@ def _not_differentiable_twice(obj):
basis_function = obj.basis_function if is_spline(obj) else obj
return isinstance(
basis_function,
(
splinebox.basis_functions.B1,
splinebox.basis_functions.CubicHermite,
splinebox.basis_functions.ExponentialHermite,
),
(splinebox.basis_functions.B1,),
)

return _not_differentiable_twice
Expand Down

0 comments on commit d6ccb70

Please sign in to comment.