Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement 3D pocket-ball collision logic #179

Merged
merged 11 commits into from
Jan 20, 2025
55 changes: 29 additions & 26 deletions pooltool/evolution/event_based/simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@
from pooltool.evolution.event_based.config import INCLUDED_EVENTS
from pooltool.objects.ball.datatypes import BallState
from pooltool.physics.engine import PhysicsEngine
from pooltool.ptmath.roots.quartic import QuarticSolver, solve_quartics
from pooltool.ptmath.roots.core import (
get_smallest_physical_root_many,
)
from pooltool.ptmath.roots.quartic import (
QuarticSolver,
solve_quartics,
)
from pooltool.system.datatypes import System

DEFAULT_ENGINE = PhysicsEngine()
Expand Down Expand Up @@ -343,7 +349,9 @@ def get_next_ball_ball_collision(
)

if len(collision_coeffs):
roots = solve_quartics(ps=np.array(collision_coeffs), solver=solver)
roots = get_smallest_physical_root_many(
solve_quartics(ps=np.array(collision_coeffs), solver=solver)
)
for root, ball_pair in zip(roots, ball_pairs):
cache[ball_pair] = shot.t + root

Expand Down Expand Up @@ -437,7 +445,9 @@ def get_next_ball_circular_cushion_event(
)

if len(collision_coeffs):
roots = solve_quartics(ps=np.array(collision_coeffs), solver=solver)
roots = get_smallest_physical_root_many(
solve_quartics(ps=np.array(collision_coeffs), solver=solver)
)

for root, ball_cushion_pair in zip(roots, ball_cushion_pairs):
cache[ball_cushion_pair] = shot.t + root
Expand Down Expand Up @@ -508,17 +518,16 @@ def get_next_ball_pocket_collision(
collision_cache: CollisionCache,
solver: QuarticSolver = QuarticSolver.HYBRID,
) -> Event:
"""Returns next ball-pocket collision"""
"""Returns next ball-pocket collision

# FIXME-3D no ball-pocket collisions
return null_event(np.inf)
Notes:
- FIXME-3D Passing solver does nothing, as the underlying solve method uses the
HYBRID approach (quartic.solve). Not sure what the solution should be.
"""

if not shot.table.has_pockets:
return null_event(np.inf)

ball_pocket_pairs: List[Tuple[str, str]] = []
collision_coeffs: List[Tuple[float, ...]] = []

cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {})

for ball in shot.balls.values():
Expand All @@ -535,25 +544,19 @@ def get_next_ball_pocket_collision(
cache[obj_ids] = np.inf
continue

ball_pocket_pairs.append(obj_ids)
collision_coeffs.append(
solve.ball_pocket_collision_coeffs(
rvw=state.rvw,
s=state.s,
a=pocket.a,
b=pocket.b,
r=pocket.radius,
mu=(params.u_s if state.s == const.sliding else params.u_r),
m=params.m,
g=params.g,
R=params.R,
)
dtau_E = solve.ball_pocket_collision_time(
rvw=state.rvw,
s=state.s,
a=pocket.a,
b=pocket.b,
r=pocket.radius,
mu=(params.u_s if state.s == const.sliding else params.u_r),
m=params.m,
g=params.g,
R=params.R,
)

if len(collision_coeffs):
roots = solve_quartics(ps=np.array(collision_coeffs), solver=solver)
for root, ball_pocket_pair in zip(roots, ball_pocket_pairs):
cache[ball_pocket_pair] = shot.t + root
cache[obj_ids] = shot.t + dtau_E

# The cache is now populated and up-to-date

Expand Down
184 changes: 137 additions & 47 deletions pooltool/evolution/event_based/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import pooltool.physics as physics
import pooltool.physics.evolve as evolve
import pooltool.ptmath as ptmath
from pooltool.physics.utils import get_airborne_time
from pooltool.ptmath.roots import quadratic, quartic
from pooltool.ptmath.roots.core import (
filter_non_physical_roots,
)


@jit(nopython=True, cache=const.use_numba_cache)
Expand Down Expand Up @@ -43,10 +48,7 @@ def ball_ball_collision_coeffs(
g2: float,
R: float,
) -> Tuple[float, float, float, float, float]:
"""Get quartic coeffs required to determine the ball-ball collision time

(just-in-time compiled)
"""
"""Get quartic coeffs required to determine the ball-ball collision time."""

c1x, c1y = rvw1[0, 0], rvw1[0, 1]
c2x, c2y = rvw2[0, 0], rvw2[0, 1]
Expand Down Expand Up @@ -105,10 +107,7 @@ def ball_table_collision_time(
g: float,
R: float,
) -> float:
"""Get time until collision between ball and table surface.

(just-in-time compiled)
"""
"""Get time until collision between ball and table surface."""
if s != const.airborne:
# Non-airborne ball cannot have a ball-table collision
return np.inf
Expand All @@ -131,10 +130,7 @@ def ball_linear_cushion_collision_time(
g: float,
R: float,
) -> float:
"""Get time until collision between ball and linear cushion segment

(just-in-time compiled)
"""
"""Get time until collision between ball and linear cushion segment."""
if s == const.spinning or s == const.pocketed or s == const.stationary:
return np.inf

Expand Down Expand Up @@ -162,41 +158,38 @@ def ball_linear_cushion_collision_time(
# C must be 0, but whether or not it is, time is a free parameter.
return np.inf

roots = np.full(4, np.nan, dtype=np.complex128)

if direction == 0:
C = l0 + lx * cx + ly * cy + R * np.sqrt(lx**2 + ly**2)
root1, root2 = ptmath.roots.quadratic.solve(A, B, C)
roots = [root1, root2]
roots[:2] = quadratic.solve(A, B, C)
elif direction == 1:
C = l0 + lx * cx + ly * cy - R * np.sqrt(lx**2 + ly**2)
root1, root2 = ptmath.roots.quadratic.solve(A, B, C)
roots = [root1, root2]
roots[:2] = quadratic.solve(A, B, C)
else:
C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx**2 + ly**2)
C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx**2 + ly**2)
root1, root2 = ptmath.roots.quadratic.solve(A, B, C1)
root3, root4 = ptmath.roots.quadratic.solve(A, B, C2)
roots = [root1, root2, root3, root4]

min_time = np.inf
for root in roots:
if np.isnan(root):
# This is an indirect test for whether the root is complex or not. This is
# because ptmath.roots.quadratic.solve returns nan if the root is complex.
roots[:2] = quadratic.solve(A, B, C1)
roots[2:] = quadratic.solve(A, B, C2)

physical_roots = filter_non_physical_roots(roots)

for root in physical_roots:
if root.real == np.inf:
continue

# FIXME-3D, ideally any sort of determination of real versus not is determined
# in filter_non_physical_roots. Remove this and observe behavior closely.
if root.real <= const.EPS:
continue

rvw_dtau = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root)
rvw_dtau = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real)
s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1)

if not (0 <= s_score <= 1):
continue

if root.real < min_time:
min_time = root.real
if 0 <= s_score <= 1:
return root.real

return min_time
return np.inf


@jit(nopython=True, cache=const.use_numba_cache)
Expand All @@ -211,10 +204,7 @@ def ball_circular_cushion_collision_coeffs(
g: float,
R: float,
) -> Tuple[float, float, float, float, float]:
"""Get quartic coeffs required to determine the ball-circular-cushion collision time

(just-in-time compiled)
"""
"""Get quartic coeffs required to determine the ball-circular-cushion collision time."""

if s == const.spinning or s == const.pocketed or s == const.stationary:
return np.inf, np.inf, np.inf, np.inf, np.inf
Expand Down Expand Up @@ -246,7 +236,7 @@ def ball_circular_cushion_collision_coeffs(


@jit(nopython=True, cache=const.use_numba_cache)
def ball_pocket_collision_coeffs(
def ball_pocket_collision_time(
rvw: NDArray[np.float64],
s: int,
a: float,
Expand All @@ -256,26 +246,30 @@ def ball_pocket_collision_coeffs(
m: float,
g: float,
R: float,
) -> Tuple[float, float, float, float, float]:
"""Get quartic coeffs required to determine the ball-pocket collision time
) -> float:
"""Determine the ball-pocket collision time.

(just-in-time compiled)
The behavior for airborne versus non-airborne state is treated differently. This
function delegates to :func:`ball_pocket_collision_time_airborne` when the state is
airborne.
"""

if s == const.spinning or s == const.pocketed or s == const.stationary:
return np.inf, np.inf, np.inf, np.inf, np.inf

phi = ptmath.projected_angle(rvw[1])
v = ptmath.norm3d(rvw[1])
return np.inf

u = get_u(rvw, R, phi, s)
if s == const.airborne:
return ball_pocket_collision_time_airborne(rvw, a, b, r, g, R)

K = -0.5 * mu * g
phi = ptmath.projected_angle(rvw[1])
v = ptmath.norm2d(rvw[1])
cos_phi = np.cos(phi)
sin_phi = np.sin(phi)

u = get_u(rvw, R, phi, s)
K = -0.5 * mu * g
ax = K * (u[0] * cos_phi - u[1] * sin_phi)
ay = K * (u[0] * sin_phi + u[1] * cos_phi)

bx, by = v * cos_phi, v * sin_phi
cx, cy = rvw[0, 0], rvw[0, 1]

Expand All @@ -285,4 +279,100 @@ def ball_pocket_collision_coeffs(
D = bx * (cx - a) + by * (cy - b)
E = 0.5 * (a**2 + b**2 + cx**2 + cy**2 - r**2) - (cx * a + cy * b)

return A, B, C, D, E
roots = quartic.solve(A, B, C, D, E)
return filter_non_physical_roots(roots)[0].real


@jit(nopython=True, cache=const.use_numba_cache)
def ball_pocket_collision_time_airborne(
rvw: NDArray[np.float64],
a: float,
b: float,
r: float,
g: float,
R: float,
) -> float:
"""Determine the ball-pocket collision time for an airborne ball.

The behavior is somewhat complicated. Here is the procedure.

Strategy 1: The xy-coordinates of where the ball lands are calculated. If that falls
within the pocket circle, a collision is returned. The collision time is chosen to
be just less than the collision time for the table collision, to guarantee temporal
precedence over the table collision.

Strategy 2: Otherwise, the influx and outflux collision times are calculated between
the ball center and a vertical cylinder that extends from the pocket's circle.
Influx collision refers to the collision with the outside of the cylinder's wall.
The outflux collision refers to the collision with the inside of the cylinder's wall
and occurs later in time. Since there is no deceleration in the xy-plane for an
airborne ball, an outflux collision is expected, meaning we expect 2 finite roots.
(This is only violated if the ball starts inside the cylinder, which results in at
most an outflux collision). The strategy is to see what the ball height is at the
time of the influx collision (h0) and the outflux collision (hf), because from these
we can determine whether or not the ball is considered to enter the pocket. The
following logic is used:

- h0 < R: The ball passes through the playing surface plane before intersecting
the pocket cylinder, guaranteeing that a ball-table collision occurs. Infinity
is returned.
- hf <= (7/5)*R: If the outflux height is less than (7/5)*R, the ball is
considered to be pocketed. This threshold height implicitly models the fact
that high velocity balls that are slightly airborne collide with table
geometry at the back of the pocket, ricocheting the ball into the pocket. The
average of the influx and outflux collision times is returned.
- hf > (7/5)*R: The ball is considered to fly over the pocket. Infinity is
returned.
"""

phi = ptmath.projected_angle(rvw[1])
v = ptmath.norm2d(rvw[1])
cos_phi = np.cos(phi)
sin_phi = np.sin(phi)
bx, by = v * cos_phi, v * sin_phi

# Strategy 1

airborne_time = get_airborne_time(rvw, R, g)
x = rvw[0, 0] + bx * airborne_time
y = rvw[0, 1] + by * airborne_time

if (x - a) ** 2 + (y - b) ** 2 < r**2:
# The ball falls directly into the pocket
return float(airborne_time - const.EPS)

# Strategy 2

cx, cy = rvw[0, 0], rvw[0, 1]

# These match the non-airborne quartic coefficients, after setting ax=ay=0.
C = 0.5 * (bx**2 + by**2)
D = bx * (cx - a) + by * (cy - b)
E = 0.5 * (a**2 + b**2 + cx**2 + cy**2 - r**2) - (cx * a + cy * b)

r1, r2 = filter_non_physical_roots(quadratic.solve(C, D, E)).real

if r1 == np.inf:
return r1

assert r2 != np.inf, "Expected finite out-flux collision with pocket"

v0z = rvw[1, 2]
z0 = rvw[0, 2]

# Height at influx collision and height at outflux collision
h0 = -0.5 * g * r1**2 + v0z * r1 + z0
hf = -0.5 * g * r2**2 + v0z * r1 + z0

if h0 < R:
# Ball hits table before reaching pocket. Safe to return inf
assert hf < h0
return np.inf

thresh = 7 / 5 * R
if hf > thresh:
# Ball flies over pocket
return np.inf

# Return average time of influx/outflux collisions
return (r1 + r2) / 2.0
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ def cue_strike(m, M, R, V0, phi, theta, a, b, english_throttle: float):
denominator = 1 + m / M + 5 / 2 / R**2 * temp
v = numerator / denominator

# 3D FIXME
v_B = -v * np.array([0, np.cos(theta), np.sin(theta)])

vec_x = -c * np.sin(theta) + b * np.cos(theta)
Expand Down
Loading
Loading