Skip to content

Commit

Permalink
Add quadratic subroutine for quartic solver
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiefl committed Jan 18, 2025
1 parent c8146b8 commit 7e904b3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 12 deletions.
58 changes: 50 additions & 8 deletions pooltool/ptmath/roots/quartic.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from pooltool.ptmath.roots.core import (
get_real_positive_smallest_roots,
)
from pooltool.ptmath.roots.quadratic import solve as solve_quadratic
from pooltool.utils.strenum import StrEnum, auto


Expand All @@ -16,6 +17,33 @@ class QuarticSolver(StrEnum):
NUMERIC = auto()


@jit(nopython=True, cache=const.use_numba_cache)
def _solve_quadratics(ps: NDArray[np.float64]) -> NDArray[np.complex128]:
"""Solves an array of quadratics.
This is used internally by the quartic solver when it is passed coefficients where
a=b=0, which make the polynomial quadratic, not quartic.
Args:
ps:
A mx3 array of polynomial coefficients, where m is the number of equations.
The columns are in the order a, b, c where these coefficients make up
the quadratic polynomial equation at^2 + bt + c = 0.
Notes:
- Output shape is mx4 to match quartic root solutions.
"""
m = ps.shape[0]
roots = np.full((m, 4), np.inf, dtype=np.complex128)

for i in range(m):
r1, r2 = solve_quadratic(ps[i, 0], ps[i, 1], ps[i, 2])
roots[i, 0] = r1
roots[i, 1] = r2

return roots


def solve_quartics(
ps: NDArray[np.float64],
solver: QuarticSolver = QuarticSolver.HYBRID,
Expand Down Expand Up @@ -43,16 +71,30 @@ def solve_quartics(
"""
assert QuarticSolver(solver)

if (ps[:, 0] == 0).any():
raise NotImplementedError(
"This quartic solver has not implemented cubic (a=0) and quadratic (a=b=0) "
"formulations, but at least one of the equations passed is cubic/quadratic. "
)
m = ps.shape[0]

# Allocate a placeholder for all roots, shape (m,4)
all_roots = np.full((m, 4), np.inf, dtype=np.complex128)

a = ps[:, 0]
b = ps[:, 1]

quartic_mask = a != 0
quadratic_mask = (a == 0) & (b == 0)
mask_cubic = (a == 0) & (b != 0)

if np.any(mask_cubic):
raise NotImplementedError("Cubic polynomials are not supported.")

if np.any(quartic_mask):
quartic_roots = _quartic_routine[solver](ps[quartic_mask])
all_roots[quartic_mask] = quartic_roots

# Get the roots for the polynomials
roots = _quartic_routine[solver](ps) # Shape m x 4, dtype complex128
best_roots = get_real_positive_smallest_roots(roots) # Shape m, dtype float64
if np.any(quadratic_mask):
quadratic_roots = _solve_quadratics(ps[quadratic_mask, 2:])
all_roots[quadratic_mask] = quadratic_roots

best_roots = get_real_positive_smallest_roots(all_roots) # shape (m,)
return best_roots


Expand Down
15 changes: 11 additions & 4 deletions tests/ptmath/roots/test_quartic.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import pytest

from pooltool.ptmath.roots import quartic
from pooltool.ptmath.roots import quadratic, quartic
from pooltool.ptmath.roots.core import get_real_positive_smallest_roots


@pytest.mark.parametrize(
Expand All @@ -27,11 +28,17 @@ def test_case1(solver: quartic.QuarticSolver):
"solver", [quartic.QuarticSolver.NUMERIC, quartic.QuarticSolver.HYBRID]
)
def test_quadratic(solver: quartic.QuarticSolver):
"""This test surfaces the fact that quartic solver can't handle quadratic equations :("""
coeffs_array = np.array((0, 0, 1, 1, 1), dtype=np.float64)[np.newaxis, :]

with pytest.raises(NotImplementedError):
quartic.solve_quartics(coeffs_array, solver)
expected = get_real_positive_smallest_roots(
np.array(quadratic.solve(*coeffs_array[0, 2:]), dtype=np.complex128)[
np.newaxis, :
]
)

result = quartic.solve_quartics(coeffs_array, solver)

assert expected == result


@pytest.mark.parametrize(
Expand Down

0 comments on commit 7e904b3

Please sign in to comment.