diff --git a/poetry.lock b/poetry.lock index f347bbff..53c3411b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -850,6 +850,22 @@ files = [ {file = "installer-0.7.0.tar.gz", hash = "sha256:a26d3e3116289bb08216e0d0f7d925fcef0b0194eedfa0c944bcaaa106c4b631"}, ] +[[package]] +name = "ipdb" +version = "0.13.13" +description = "IPython-enabled pdb" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +files = [ + {file = "ipdb-0.13.13-py3-none-any.whl", hash = "sha256:45529994741c4ab6d2388bfa5d7b725c2cf7fe9deffabdb8a6113aa5ed449ed4"}, + {file = "ipdb-0.13.13.tar.gz", hash = "sha256:e3ac6018ef05126d442af680aad863006ec19d02290561ac88b8b1c0b0cfc726"}, +] + +[package.dependencies] +decorator = {version = "*", markers = "python_version > \"3.6\""} +ipython = {version = ">=7.31.1", markers = "python_version > \"3.6\""} +tomli = {version = "*", markers = "python_version > \"3.6\" and python_version < \"3.11\""} + [[package]] name = "ipython" version = "8.18.1" @@ -2708,4 +2724,4 @@ test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "338ef1e8dc6ee8b9787d470f6c7d3404ca85f898b1bd3c6429d19a4d3afce297" +content-hash = "04a6d5859d7eb66caf0b2099295895f644f458f39840d25fb97c1b3e77ad44a4" diff --git a/pooltool/events/datatypes.py b/pooltool/events/datatypes.py index 8bfb9898..741d6b9e 100644 --- a/pooltool/events/datatypes.py +++ b/pooltool/events/datatypes.py @@ -58,22 +58,22 @@ class EventType(strenum.StrEnum): def is_collision(self) -> bool: """Returns whether the member is a collision""" - return self in ( + return self in { EventType.BALL_BALL, EventType.BALL_CIRCULAR_CUSHION, EventType.BALL_LINEAR_CUSHION, EventType.BALL_POCKET, EventType.STICK_BALL, - ) + } def is_transition(self) -> bool: """Returns whether the member is a transition""" - return self in ( + return self in { EventType.SPINNING_STATIONARY, EventType.ROLLING_STATIONARY, EventType.ROLLING_SPINNING, EventType.SLIDING_ROLLING, - ) + } Object = Union[ diff --git a/pooltool/events/utils.py b/pooltool/events/utils.py new file mode 100644 index 00000000..fdb76139 --- /dev/null +++ b/pooltool/events/utils.py @@ -0,0 +1,15 @@ +from typing import Dict, Set + +from pooltool.events.datatypes import EventType + +event_type_to_ball_indices: Dict[EventType, Set[int]] = { + EventType.BALL_BALL: {0, 1}, + EventType.BALL_LINEAR_CUSHION: {0}, + EventType.BALL_CIRCULAR_CUSHION: {0}, + EventType.BALL_POCKET: {0}, + EventType.STICK_BALL: {1}, + EventType.SPINNING_STATIONARY: {0}, + EventType.ROLLING_STATIONARY: {0}, + EventType.ROLLING_SPINNING: {0}, + EventType.SLIDING_ROLLING: {0}, +} diff --git a/pooltool/evolution/event_based/cache.py b/pooltool/evolution/event_based/cache.py new file mode 100644 index 00000000..df765849 --- /dev/null +++ b/pooltool/evolution/event_based/cache.py @@ -0,0 +1,163 @@ +#! /usr/bin/env python + +from __future__ import annotations + +from typing import Dict, Set, Tuple + +import attrs +import numpy as np + +import pooltool.constants as const +import pooltool.ptmath as ptmath +from pooltool.events import ( + AgentType, + Event, + EventType, + null_event, + rolling_spinning_transition, + rolling_stationary_transition, + sliding_rolling_transition, + spinning_stationary_transition, +) +from pooltool.events.utils import event_type_to_ball_indices +from pooltool.objects.ball.datatypes import Ball +from pooltool.system.datatypes import System + + +def _null() -> Dict[str, Event]: + return {"null": null_event(time=np.inf)} + + +@attrs.define +class TransitionCache: + """A cache for managing and retrieving the next transition events for balls. + + This class maintains a dictionary of transitions, where each key is ball ID, and + each value is the next transition associated with that object. + + Attributes: + transitions: + A dictionary mapping ball IDs to their corresponding next event. + + See Also: + - For practical and historical reasons, events are cached differently depending + on whether they are collision events or transition events. For collision + event caching, see :class:`CollisionCache`. + """ + + transitions: Dict[str, Event] = attrs.field(factory=_null) + + def get_next(self) -> Event: + return min( + (trans for trans in self.transitions.values()), key=lambda event: event.time + ) + + def update(self, event: Event) -> None: + """Update transition cache for all balls in Event""" + for agent in event.agents: + if agent.agent_type == AgentType.BALL: + assert isinstance(ball := agent.final, Ball) + self.transitions[agent.id] = _next_transition(ball) + + @classmethod + def create(cls, shot: System) -> TransitionCache: + return cls( + {ball_id: _next_transition(ball) for ball_id, ball in shot.balls.items()} + ) + + +def _next_transition(ball: Ball) -> Event: + if ball.state.s == const.stationary or ball.state.s == const.pocketed: + return null_event(time=np.inf) + + elif ball.state.s == const.spinning: + dtau_E = ptmath.get_spin_time( + ball.state.rvw, ball.params.R, ball.params.u_sp, ball.params.g + ) + return spinning_stationary_transition(ball, ball.state.t + dtau_E) + + elif ball.state.s == const.rolling: + dtau_E_spin = ptmath.get_spin_time( + ball.state.rvw, ball.params.R, ball.params.u_sp, ball.params.g + ) + dtau_E_roll = ptmath.get_roll_time( + ball.state.rvw, ball.params.u_r, ball.params.g + ) + + if dtau_E_spin > dtau_E_roll: + return rolling_spinning_transition(ball, ball.state.t + dtau_E_roll) + else: + return rolling_stationary_transition(ball, ball.state.t + dtau_E_roll) + + elif ball.state.s == const.sliding: + dtau_E = ptmath.get_slide_time( + ball.state.rvw, ball.params.R, ball.params.u_s, ball.params.g + ) + return sliding_rolling_transition(ball, ball.state.t + dtau_E) + + else: + raise NotImplementedError(f"Unknown '{ball.state.s=}'") + + +@attrs.define +class CollisionCache: + """A cache for storing and managing collision times between objects. + + This class is used as a cache for possible future collision times. By caching + collision times whenever they are calculated, re-calculating them during each step + of the shot evolution algorithm can be avoided in many instances. + + It also provides functionality to invalidate cached events based on realized events, + ensuring that outdated data does not persist in the cache. For example, if a + ball-transition event for ball with ID "6" is passed to :meth:`invalidate`, all + cached event times involving ball ID "6" are removed from the cache, since they are + no longer valid. + + Attributes: + times: + A dictionary where each key is an event type, and each value is another + dictionary mapping tuples of object IDs to their corresponding collision + times. + + Properties: + size: + The total number of cached events. + + See Also: + - For practical and historical reasons, events are cached differently depending + on whether they are collision events or transition events. For transition + event caching, see :class:`TransitionCache`. + """ + + times: Dict[EventType, Dict[Tuple[str, str], float]] = attrs.field(factory=dict) + + @property + def size(self) -> int: + return sum(len(cache) for cache in self.times.values()) + + def _get_invalid_ball_ids(self, event: Event) -> Set[str]: + return { + event.ids[ball_idx] + for ball_idx in event_type_to_ball_indices[event.event_type] + } + + def invalidate(self, event: Event) -> None: + invalid_ball_ids = self._get_invalid_ball_ids(event) + + for event_type, event_times in self.times.items(): + keys_to_delete = [] + + for key in event_times: + # Identify which indices in the key should be checked based on the event type + ball_indices = event_type_to_ball_indices.get(event_type, []) + + # Check if any of the relevant ball IDs in the key match the invalid IDs + if any(key[idx] in invalid_ball_ids for idx in ball_indices): + keys_to_delete.append(key) + + for key in keys_to_delete: + del event_times[key] + + @classmethod + def create(cls) -> CollisionCache: + return cls() diff --git a/pooltool/evolution/event_based/simulate.py b/pooltool/evolution/event_based/simulate.py index faf2509c..ef862d89 100755 --- a/pooltool/evolution/event_based/simulate.py +++ b/pooltool/evolution/event_based/simulate.py @@ -3,16 +3,14 @@ from __future__ import annotations from itertools import combinations -from typing import Dict, Optional, Set +from typing import List, Optional, Set, Tuple -import attrs import numpy as np import pooltool.constants as const import pooltool.physics.evolve as evolve import pooltool.ptmath as ptmath from pooltool.events import ( - AgentType, Event, EventType, ball_ball_collision, @@ -20,28 +18,43 @@ ball_linear_cushion_collision, ball_pocket_collision, null_event, - rolling_spinning_transition, - rolling_stationary_transition, - sliding_rolling_transition, - spinning_stationary_transition, stick_ball_collision, ) from pooltool.evolution.continuize import continuize from pooltool.evolution.event_based import solve +from pooltool.evolution.event_based.cache import CollisionCache, TransitionCache from pooltool.evolution.event_based.config import INCLUDED_EVENTS -from pooltool.objects.ball.datatypes import Ball, BallState -from pooltool.objects.table.components import ( - CircularCushionSegment, - LinearCushionSegment, - Pocket, -) +from pooltool.objects.ball.datatypes import BallState from pooltool.physics.engine import PhysicsEngine -from pooltool.ptmath.roots.quartic import QuarticSolver +from pooltool.ptmath.roots.quartic import QuarticSolver, solve_quartics from pooltool.system.datatypes import System DEFAULT_ENGINE = PhysicsEngine() +def _evolve(shot: System, dt: float): + """Evolves current ball an amount of time dt + + FIXME This is very inefficent. each ball should store its natural trajectory + thereby avoid a call to the clunky evolve_ball_motion. It could even be a + partial function so parameters don't continuously need to be passed + """ + + for ball in shot.balls.values(): + rvw, _ = evolve.evolve_ball_motion( + state=ball.state.s, + rvw=ball.state.rvw, + R=ball.params.R, + m=ball.params.m, + u_s=ball.params.u_s, + u_sp=ball.params.u_sp, + u_r=ball.params.u_r, + g=ball.params.g, + t=dt, + ) + ball.state = BallState(rvw, ball.state.s, shot.t + dt) + + def simulate( shot: System, engine: Optional[PhysicsEngine] = None, @@ -152,6 +165,7 @@ def simulate( engine.resolver.resolve(shot, event) shot._update_history(event) + collision_cache = CollisionCache.create() transition_cache = TransitionCache.create(shot) events = 0 @@ -159,6 +173,7 @@ def simulate( event = get_next_event( shot, transition_cache=transition_cache, + collision_cache=collision_cache, quartic_solver=quartic_solver, ) @@ -171,6 +186,7 @@ def simulate( if event.event_type in include: engine.resolver.resolve(shot, event) transition_cache.update(event) + collision_cache.invalidate(event) shot._update_history(event) @@ -190,33 +206,11 @@ def simulate( return shot -def _evolve(shot: System, dt: float): - """Evolves current ball an amount of time dt - - FIXME This is very inefficent. each ball should store its natural trajectory - thereby avoid a call to the clunky evolve_ball_motion. It could even be a - partial function so parameters don't continuously need to be passed - """ - - for ball in shot.balls.values(): - rvw, _ = evolve.evolve_ball_motion( - state=ball.state.s, - rvw=ball.state.rvw, - R=ball.params.R, - m=ball.params.m, - u_s=ball.params.u_s, - u_sp=ball.params.u_sp, - u_r=ball.params.u_r, - g=ball.params.g, - t=dt, - ) - ball.state = BallState(rvw, ball.state.s, shot.t + dt) - - def get_next_event( shot: System, *, transition_cache: Optional[TransitionCache] = None, + collision_cache: Optional[CollisionCache] = None, quartic_solver: QuarticSolver = QuarticSolver.HYBRID, ) -> Event: # Start by assuming next event doesn't happen @@ -225,101 +219,57 @@ def get_next_event( if transition_cache is None: transition_cache = TransitionCache.create(shot) + if collision_cache is None: + collision_cache = CollisionCache.create() + transition_event = transition_cache.get_next() if transition_event.time < event.time: event = transition_event - ball_ball_event = get_next_ball_ball_collision(shot, solver=quartic_solver) + ball_ball_event = get_next_ball_ball_collision( + shot, collision_cache=collision_cache, solver=quartic_solver + ) if ball_ball_event.time < event.time: event = ball_ball_event - ball_linear_cushion_event = get_next_ball_linear_cushion_collision(shot) - if ball_linear_cushion_event.time < event.time: - event = ball_linear_cushion_event - ball_circular_cushion_event = get_next_ball_circular_cushion_event( - shot, solver=quartic_solver + shot, collision_cache=collision_cache, solver=quartic_solver ) if ball_circular_cushion_event.time < event.time: event = ball_circular_cushion_event - ball_pocket_event = get_next_ball_pocket_collision(shot, solver=quartic_solver) + ball_linear_cushion_event = get_next_ball_linear_cushion_collision( + shot, collision_cache=collision_cache + ) + if ball_linear_cushion_event.time < event.time: + event = ball_linear_cushion_event + + ball_pocket_event = get_next_ball_pocket_collision( + shot, collision_cache=collision_cache, solver=quartic_solver + ) if ball_pocket_event.time < event.time: event = ball_pocket_event return event -def _null() -> Dict[str, Event]: - return {"null": null_event(time=np.inf)} - - -@attrs.define -class TransitionCache: - transitions: Dict[str, Event] = attrs.field(factory=_null) - - def get_next(self) -> Event: - return min( - (trans for trans in self.transitions.values()), key=lambda event: event.time - ) - - def update(self, event: Event) -> None: - """Update transition cache for all balls in Event""" - for agent in event.agents: - if agent.agent_type == AgentType.BALL: - assert isinstance(ball := agent.final, Ball) - self.transitions[agent.id] = _next_transition(ball) - - @classmethod - def create(cls, shot: System) -> TransitionCache: - return cls( - {ball_id: _next_transition(ball) for ball_id, ball in shot.balls.items()} - ) - - -def _next_transition(ball: Ball) -> Event: - if ball.state.s == const.stationary or ball.state.s == const.pocketed: - return null_event(time=np.inf) - - elif ball.state.s == const.spinning: - dtau_E = ptmath.get_spin_time( - ball.state.rvw, ball.params.R, ball.params.u_sp, ball.params.g - ) - return spinning_stationary_transition(ball, ball.state.t + dtau_E) - - elif ball.state.s == const.rolling: - dtau_E_spin = ptmath.get_spin_time( - ball.state.rvw, ball.params.R, ball.params.u_sp, ball.params.g - ) - dtau_E_roll = ptmath.get_roll_time( - ball.state.rvw, ball.params.u_r, ball.params.g - ) - - if dtau_E_spin > dtau_E_roll: - return rolling_spinning_transition(ball, ball.state.t + dtau_E_roll) - else: - return rolling_stationary_transition(ball, ball.state.t + dtau_E_roll) - - elif ball.state.s == const.sliding: - dtau_E = ptmath.get_slide_time( - ball.state.rvw, ball.params.R, ball.params.u_s, ball.params.g - ) - return sliding_rolling_transition(ball, ball.state.t + dtau_E) - - else: - raise NotImplementedError(f"Unknown '{ball.state.s=}'") - - def get_next_ball_ball_collision( - shot: System, solver: QuarticSolver = QuarticSolver.HYBRID + shot: System, + collision_cache: CollisionCache, + solver: QuarticSolver = QuarticSolver.HYBRID, ) -> Event: """Returns next ball-ball collision""" - dtau_E = np.inf - ball_ids = [] - collision_coeffs = [] + ball_pairs: List[Tuple[str, str]] = [] + collision_coeffs: List[Tuple[float, ...]] = [] + + cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) for ball1, ball2 in combinations(shot.balls.values(), 2): + ball_pair = (ball1.id, ball2.id) + if ball_pair in cache: + continue + ball1_state = ball1.state ball1_params = ball1.params @@ -327,78 +277,90 @@ def get_next_ball_ball_collision( ball2_params = ball2.params if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed: - continue - - if ( + cache[ball_pair] = np.inf + elif ( ball1_state.s in const.nontranslating and ball2_state.s in const.nontranslating ): - continue - - if ( + cache[ball_pair] = np.inf + elif ( ptmath.norm3d(ball1_state.rvw[0] - ball2_state.rvw[0]) < ball1_params.R + ball2_params.R ): # If balls are intersecting, avoid internal collisions - continue - - collision_coeffs.append( - solve.ball_ball_collision_coeffs( - rvw1=ball1_state.rvw, - rvw2=ball2_state.rvw, - s1=ball1_state.s, - s2=ball2_state.s, - mu1=( - ball1_params.u_s - if ball1_state.s == const.sliding - else ball1_params.u_r - ), - mu2=( - ball2_params.u_s - if ball2_state.s == const.sliding - else ball2_params.u_r - ), - m1=ball1_params.m, - m2=ball2_params.m, - g1=ball1_params.g, - g2=ball2_params.g, - R=ball1_params.R, + cache[ball_pair] = np.inf + else: + ball_pairs.append(ball_pair) + collision_coeffs.append( + solve.ball_ball_collision_coeffs( + rvw1=ball1_state.rvw, + rvw2=ball2_state.rvw, + s1=ball1_state.s, + s2=ball2_state.s, + mu1=( + ball1_params.u_s + if ball1_state.s == const.sliding + else ball1_params.u_r + ), + mu2=( + ball2_params.u_s + if ball2_state.s == const.sliding + else ball2_params.u_r + ), + m1=ball1_params.m, + m2=ball2_params.m, + g1=ball1_params.g, + g2=ball2_params.g, + R=ball1_params.R, + ) ) - ) - - ball_ids.append((ball1.id, ball2.id)) - if not len(collision_coeffs): - # There are no collisions to test for - return ball_ball_collision(Ball.dummy(), Ball.dummy(), shot.t + dtau_E) + if len(collision_coeffs): + roots = solve_quartics(ps=np.array(collision_coeffs), solver=solver) + for root, ball_pair in zip(roots, ball_pairs): + cache[ball_pair] = shot.t + root - dtau_E, index = ptmath.roots.quartic.minimum_quartic_root( - ps=np.array(collision_coeffs), solver=solver - ) + # The cache is now populated and up-to-date - ball1_id, ball2_id = ball_ids[index] - ball1, ball2 = shot.balls[ball1_id], shot.balls[ball2_id] + ball_pair = min(cache, key=lambda k: cache[k]) - return ball_ball_collision(ball1, ball2, shot.t + dtau_E) + return ball_ball_collision( + ball1=shot.balls[ball_pair[0]], + ball2=shot.balls[ball_pair[1]], + time=cache[ball_pair], + ) def get_next_ball_circular_cushion_event( - shot: System, solver: QuarticSolver = QuarticSolver.HYBRID + shot: System, + collision_cache: CollisionCache, + solver: QuarticSolver = QuarticSolver.HYBRID, ) -> Event: """Returns next ball-cushion collision (circular cushion segment)""" - dtau_E = np.inf - agent_ids = [] - collision_coeffs = [] + if not shot.table.has_circular_cushions: + return null_event(np.inf) - for ball in shot.balls.values(): - if ball.state.s in const.nontranslating: - continue + ball_cushion_pairs: List[Tuple[str, str]] = [] + collision_coeffs: List[Tuple[float, ...]] = [] + cache = collision_cache.times.setdefault(EventType.BALL_CIRCULAR_CUSHION, {}) + + for ball in shot.balls.values(): state = ball.state params = ball.params for cushion in shot.table.cushion_segments.circular.values(): + obj_ids = (ball.id, cushion.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + ball_cushion_pairs.append(obj_ids) collision_coeffs.append( solve.ball_circular_cushion_collision_coeffs( rvw=state.rvw, @@ -413,41 +375,46 @@ def get_next_ball_circular_cushion_event( ) ) - agent_ids.append((ball.id, cushion.id)) + if len(collision_coeffs): + roots = solve_quartics(ps=np.array(collision_coeffs), solver=solver) + for root, ball_cushion_pair in zip(roots, ball_cushion_pairs): + cache[ball_cushion_pair] = shot.t + root - if not len(collision_coeffs): - # There are no collisions to test for - return ball_circular_cushion_collision( - Ball.dummy(), CircularCushionSegment.dummy(), shot.t + dtau_E - ) + # The cache is now populated and up-to-date - dtau_E, index = ptmath.roots.quartic.minimum_quartic_root( - ps=np.array(collision_coeffs), solver=solver - ) + ball_id, cushion_id = min(cache, key=lambda k: cache[k]) - ball_id, cushion_id = agent_ids[index] - ball, cushion = ( - shot.balls[ball_id], - shot.table.cushion_segments.circular[cushion_id], + return ball_circular_cushion_collision( + ball=shot.balls[ball_id], + cushion=shot.table.cushion_segments.circular[cushion_id], + time=cache[(ball_id, cushion_id)], ) - return ball_circular_cushion_collision(ball, cushion, shot.t + dtau_E) - -def get_next_ball_linear_cushion_collision(shot: System) -> Event: +def get_next_ball_linear_cushion_collision( + shot: System, collision_cache: CollisionCache +) -> Event: """Returns next ball-cushion collision (linear cushion segment)""" - dtau_E_min = np.inf - involved_agents = (Ball.dummy(), LinearCushionSegment.dummy()) + if not shot.table.has_linear_cushions: + return null_event(np.inf) - for ball in shot.balls.values(): - if ball.state.s in const.nontranslating: - continue + cache = collision_cache.times.setdefault(EventType.BALL_LINEAR_CUSHION, {}) + for ball in shot.balls.values(): state = ball.state params = ball.params for cushion in shot.table.cushion_segments.linear.values(): + obj_ids = (ball.id, cushion.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + dtau_E = solve.ball_linear_cushion_collision_time( rvw=state.rvw, s=state.s, @@ -463,32 +430,47 @@ def get_next_ball_linear_cushion_collision(shot: System) -> Event: R=params.R, ) - if dtau_E < dtau_E_min: - involved_agents = (ball, cushion) - dtau_E_min = dtau_E + cache[obj_ids] = shot.t + dtau_E - dtau_E = dtau_E_min + obj_ids = min(cache, key=lambda k: cache[k]) - return ball_linear_cushion_collision(*involved_agents, shot.t + dtau_E) + return ball_linear_cushion_collision( + ball=shot.balls[obj_ids[0]], + cushion=shot.table.cushion_segments.linear[obj_ids[1]], + time=cache[obj_ids], + ) def get_next_ball_pocket_collision( - shot: System, solver: QuarticSolver = QuarticSolver.HYBRID + shot: System, + collision_cache: CollisionCache, + solver: QuarticSolver = QuarticSolver.HYBRID, ) -> Event: """Returns next ball-pocket collision""" - dtau_E = np.inf - agent_ids = [] - collision_coeffs = [] + if not shot.table.has_pockets: + return null_event(np.inf) - for ball in shot.balls.values(): - if ball.state.s in const.nontranslating: - continue + ball_pocket_pairs: List[Tuple[str, str]] = [] + collision_coeffs: List[Tuple[float, ...]] = [] + + cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) + for ball in shot.balls.values(): state = ball.state params = ball.params for pocket in shot.table.pockets.values(): + obj_ids = (ball.id, pocket.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + ball_pocket_pairs.append(obj_ids) collision_coeffs.append( solve.ball_pocket_collision_coeffs( rvw=state.rvw, @@ -503,17 +485,17 @@ def get_next_ball_pocket_collision( ) ) - agent_ids.append((ball.id, pocket.id)) + if len(collision_coeffs): + roots = solve_quartics(ps=np.array(collision_coeffs), solver=solver) + for root, ball_pocket_pair in zip(roots, ball_pocket_pairs): + cache[ball_pocket_pair] = shot.t + root - if not len(collision_coeffs): - # There are no collisions to test for - return ball_pocket_collision(Ball.dummy(), Pocket.dummy(), shot.t + dtau_E) + # The cache is now populated and up-to-date - dtau_E, index = ptmath.roots.quartic.minimum_quartic_root( - ps=np.array(collision_coeffs), solver=solver - ) - - ball_id, pocket_id = agent_ids[index] - ball, pocket = shot.balls[ball_id], shot.table.pockets[pocket_id] + ball_id, pocket_id = min(cache, key=lambda k: cache[k]) - return ball_pocket_collision(ball, pocket, shot.t + dtau_E) + return ball_pocket_collision( + ball=shot.balls[ball_id], + pocket=shot.table.pockets[pocket_id], + time=cache[(ball_id, pocket_id)], + ) diff --git a/pooltool/evolution/event_based/test_simulate.py b/pooltool/evolution/event_based/test_simulate.py index 4f5cf4d1..20f50587 100644 --- a/pooltool/evolution/event_based/test_simulate.py +++ b/pooltool/evolution/event_based/test_simulate.py @@ -5,6 +5,7 @@ import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.events import EventType, ball_ball_collision, ball_pocket_collision +from pooltool.evolution.event_based.cache import CollisionCache from pooltool.evolution.event_based.simulate import ( get_next_ball_ball_collision, get_next_event, @@ -173,7 +174,7 @@ def test_case3(solver: quartic.QuarticSolver): expected = pytest.approx(5.810383731499328e-06, abs=1e-9) assert event.time == expected - assert quartic.minimum_quartic_root(coeffs_array, solver=solver)[0] == expected + assert quartic.solve_quartics(coeffs_array, solver=solver)[0] == expected @pytest.mark.parametrize( @@ -319,7 +320,7 @@ def _move_cue(system: System, phi: float) -> None: coeffs_array = np.array([coeffs], dtype=np.float64) - root = quartic.minimum_quartic_root(coeffs_array)[0] + root = quartic.solve_quartics(coeffs_array)[0] if phi < 90: assert root == np.inf @@ -423,7 +424,7 @@ def true_time_to_collision(eps, V0, mu_r, g): coeffs_array = np.array([coeffs], dtype=np.float64) truth = true_time_to_collision(eps, V0, ball1.params.u_r, ball1.params.g) - calculated = quartic.minimum_quartic_root(coeffs_array, solver=solver)[0] + calculated = quartic.solve_quartics(coeffs_array, solver=solver)[0] diff = abs(calculated - truth) assert diff < 10e-12 # Less than 10 femptosecond difference @@ -478,4 +479,7 @@ def test_no_ball_ball_collisions_for_intersecting_balls(solver: quartic.QuarticS assert ( get_next_event(system, quartic_solver=solver).event_type != EventType.BALL_BALL ) - assert get_next_ball_ball_collision(system, solver=solver).time == np.inf + assert ( + get_next_ball_ball_collision(system, CollisionCache(), solver=solver).time + == np.inf + ) diff --git a/pooltool/objects/table/datatypes.py b/pooltool/objects/table/datatypes.py index 6b3fbaf3..c3c7fdb0 100644 --- a/pooltool/objects/table/datatypes.py +++ b/pooltool/objects/table/datatypes.py @@ -102,6 +102,18 @@ def center(self) -> Tuple[float, float]: return self.w / 2, self.l / 2 + @property + def has_linear_cushions(self) -> bool: + return bool(len(self.cushion_segments.linear)) + + @property + def has_circular_cushions(self) -> bool: + return bool(len(self.cushion_segments.circular)) + + @property + def has_pockets(self) -> bool: + return bool(len(self.pockets)) + def copy(self) -> Table: """Create a copy.""" # Delegates the deep-ish copying of CushionSegments and Pocket to their respective diff --git a/pooltool/ptmath/roots/__init__.py b/pooltool/ptmath/roots/__init__.py index 725b35bd..e0fcacca 100644 --- a/pooltool/ptmath/roots/__init__.py +++ b/pooltool/ptmath/roots/__init__.py @@ -1,11 +1,9 @@ import pooltool.ptmath.roots.quadratic as quadratic import pooltool.ptmath.roots.quartic as quartic -from pooltool.ptmath.roots.core import min_real_root -from pooltool.ptmath.roots.quartic import minimum_quartic_root +from pooltool.ptmath.roots.quartic import solve_quartics __all__ = [ "quadratic", "quartic", - "min_real_root", - "minimum_quartic_root", + "solve_quartics", ] diff --git a/pooltool/ptmath/roots/core.py b/pooltool/ptmath/roots/core.py index 26b79bb1..c4559cdc 100644 --- a/pooltool/ptmath/roots/core.py +++ b/pooltool/ptmath/roots/core.py @@ -1,32 +1,25 @@ import numpy as np -from numba import jit from numpy.typing import NDArray -import pooltool.constants as const - -def min_real_root( +def get_real_positive_smallest_roots( roots: NDArray[np.complex128], abs_or_rel_cutoff: float = 1e-3, rtol: float = 1e-3, atol: float = 1e-9, -) -> np.complex128: - """Given an array of roots, find the minimum, real, positive root - - Note: This is faster than a numba vector implementation and a numba loop - implementation. +) -> NDArray[np.float64]: + """Returns the smallest postive and real root for each set of roots. Args: roots: - A 1D array of roots. + A mxn array of polynomial root solutions, where m is the number of equations + and n is the order of the polynomial. abs_or_rel_cutoff: - The criteria for a root being real depends on the magnitude of it's real + The criteria for a root being real depends on the magnitude of its real component. If it's large, we require the imaginary component to be less than atol in absolute terms. But when the real component is small, we require the imaginary component be less than a fraction, rtol, of the real component. - This is because when the real component is small, perhaps even comparable to - atol, using an absolute cutoff for the imaginary component doesn't make much - sense. abs_or_rel_cutoff defines a threshold for the magnitude of the real + abs_or_rel_cutoff defines a threshold for the magnitude of the real component, above which atol is used and below which rtol is used. atol: A root r (with abs(r.real) >= abs_or_rel_cutoff) is considered real if @@ -37,10 +30,9 @@ def min_real_root( the root is considered real if r.imag == 0, too. Returns: - root: - The root determined to be smallest, real, and positive. Note, a complex - datatype is returned, and it may have residual complex components. Use - root.real for only the real component. + An array of shape m. Each value is the smallest root that is real and + positive. If no such root exists (e.g. all roots are complex), then + `np.inf` is used. """ positive = roots.real >= 0.0 @@ -55,20 +47,10 @@ def min_real_root( small_keep2 = (real_mag == 0) & (imag_mag == 0) small_keep = (small_keep1 | small_keep2) & positive - candidates = roots[(small & small_keep) | (big & big_keep)] - - if candidates.size == 0: - return np.complex128(np.inf) - - # Return candidate with the smallest real component - return candidates[candidates.real.argmin()] + is_real = (small & small_keep) | (big & big_keep) + processed_roots = np.where(is_real, roots, np.complex128(np.inf)) + # Find the minimum real positive root in each row + min_real_positive_roots = np.min(processed_roots.real, axis=1) -@jit(nopython=True, cache=const.use_numba_cache) -def find_first_row_with_value(arr, X) -> int: - """Find the index of the first row in a 2D array that contains a specific value.""" - for i in range(arr.shape[0]): - for j in range(arr.shape[1]): - if arr[i, j] == X: - return i - return -1 + return min_real_positive_roots diff --git a/pooltool/ptmath/roots/quartic.py b/pooltool/ptmath/roots/quartic.py index 2bbe0eba..4f9c0242 100644 --- a/pooltool/ptmath/roots/quartic.py +++ b/pooltool/ptmath/roots/quartic.py @@ -5,7 +5,9 @@ from numpy.typing import NDArray import pooltool.constants as const -from pooltool.ptmath.roots.core import find_first_row_with_value, min_real_root +from pooltool.ptmath.roots.core import ( + get_real_positive_smallest_roots, +) from pooltool.utils.strenum import StrEnum, auto @@ -14,10 +16,10 @@ class QuarticSolver(StrEnum): NUMERIC = auto() -def minimum_quartic_root( +def solve_quartics( ps: NDArray[np.float64], solver: QuarticSolver = QuarticSolver.HYBRID -) -> Tuple[float, int]: - """Solves an array of quartic coefficients, returns smallest, real, positive root +) -> NDArray[np.float64]: + """Returns the smallest positive and real root for each quartic polynomial. Args: ps: @@ -29,24 +31,18 @@ def minimum_quartic_root( pooltool.ptmath.roots.quartic.QuarticSolver. Returns: - (real_root, index): - real_root is the minimum real root from the set of polynomials, and `index` - specifies the index of the responsible polynomial. i.e. the polynomial with - the root real_root is ps[index, :] + roots: + An array of shape m. Each value is the smallest root that is real and + positive. If no such root exists (e.g. all roots have complex), then + `np.inf` is returned. """ # Get the roots for the polynomials assert QuarticSolver(solver) - roots = _quartic_routine[solver](ps) - best_root = min_real_root(roots.flatten()) + roots = _quartic_routine[solver](ps) # Shape m x 4, dtype complex128 + best_roots = get_real_positive_smallest_roots(roots) # Shape m, dtype float64 - if best_root == np.inf: - return np.inf, 0 - - index = find_first_row_with_value(roots, best_root) - assert index > -1 - - return float(best_root.real), index + return best_roots def solve_many_numerical(p): diff --git a/pooltool/ptmath/roots/test_quartic.py b/pooltool/ptmath/roots/test_quartic.py index 40396905..9e054060 100644 --- a/pooltool/ptmath/roots/test_quartic.py +++ b/pooltool/ptmath/roots/test_quartic.py @@ -18,7 +18,7 @@ def test_case1(solver: quartic.QuarticSolver): expected = 0.048943195217641386 coeffs_array = np.array(coeffs)[np.newaxis, :] - assert quartic.minimum_quartic_root(coeffs_array, solver)[0] == pytest.approx( + assert quartic.solve_quartics(coeffs_array, solver)[0] == pytest.approx( expected, rel=1e-4 ) diff --git a/pyproject.toml b/pyproject.toml index a4d2d875..1a5373ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ colored = ">=2.2.4" # TODO switch to something ubiquitous ipython = ">=8.18.1" # Publishing poetry-dynamic-versioning = {extras = ["plugin"], version = ">=1.4.0"} +ipdb = "^0.13.13" [tool.poetry.group.docs] optional = true