Skip to content

Commit

Permalink
Add some get_* methods to Event API (#174)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekiefl authored Jan 4, 2025
1 parent 899d667 commit 3d8fd67
Show file tree
Hide file tree
Showing 4 changed files with 248 additions and 23 deletions.
136 changes: 119 additions & 17 deletions pooltool/events/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from functools import partial
from typing import Any, Dict, Optional, Tuple, Type, Union
from typing import Any, Dict, Optional, Tuple, Type, Union, cast

from attrs import define, evolve, field
from cattrs.converters import Converter
Expand Down Expand Up @@ -75,6 +75,35 @@ def is_transition(self) -> bool:
EventType.SLIDING_ROLLING,
}

def has_ball(self) -> bool:
"""Returns True if this event type can involve a Ball."""
return (
self
in {
EventType.BALL_BALL,
EventType.BALL_LINEAR_CUSHION,
EventType.BALL_CIRCULAR_CUSHION,
EventType.BALL_POCKET,
EventType.STICK_BALL,
}
or self.is_transition()
)

def has_cushion(self) -> bool:
"""Returns True if this event type can involve a cushion (linear or circular)."""
return self in {
EventType.BALL_LINEAR_CUSHION,
EventType.BALL_CIRCULAR_CUSHION,
}

def has_pocket(self) -> bool:
"""Returns True if this event type can involve a Pocket."""
return self == EventType.BALL_POCKET

def has_stick(self) -> bool:
"""Returns True if this event type can involve a CueStick."""
return self == EventType.STICK_BALL


Object = Union[
NullObject,
Expand Down Expand Up @@ -185,22 +214,6 @@ def set_final(self, obj: Object) -> None:
else:
self.final = obj.copy()

def matches(self, obj: Object) -> bool:
"""Determines if the given object matches the agent.
It checks if the object is of the correct class type and if the IDs match.
Args:
obj: The object to compare with the agent.
Returns:
bool:
True if the object's class type and ID match the agent's type and ID,
False otherwise.
"""
correct_class = _class_to_type[type(obj)] == self.agent_type
return correct_class and obj.id == self.id

@staticmethod
def from_object(obj: Object, set_initial: bool = False) -> Agent:
"""Creates an agent instance from an object.
Expand Down Expand Up @@ -228,6 +241,17 @@ def copy(self) -> Agent:
"""Create a copy."""
return evolve(self)

def _get_state(self, initial: bool) -> Object:
"""Return either the initial or final state of the given agent.
Raises ValueError if that state is None.
"""
obj = self.initial if initial else self.final
if obj is None:
which = "initial" if initial else "final"
raise ValueError(f"Agent '{self.id}' has no {which} state in this event.")
return obj


def _disambiguate_agent_structuring(
uo: Dict[str, Any], _: Type[Agent], con: Converter
Expand Down Expand Up @@ -329,3 +353,81 @@ def copy(self) -> Event:
"""Create a copy."""
# NOTE is this deep-ish copy?
return evolve(self)

def _find_agent(self, agent_type: AgentType, agent_id: str) -> Agent:
"""Return the Agent with the specified agent_type and ID.
Raises:
ValueError if not found.
"""
for agent in self.agents:
if agent.agent_type == agent_type and agent.id == agent_id:
return agent
raise ValueError(
f"No agent of type {agent_type} with ID '{agent_id}' found in this event."
)

def get_ball(self, ball_id: str, initial: bool = True) -> Ball:
"""Return the Ball object with the given ID, either final or initial.
Args:
ball_id: The ID of the ball to retrieve.
initial: If True, return the ball's initial state; otherwise final state.
Raises:
ValueError: If the event does not involve a ball or if no matching ball is found.
"""
if not self.event_type.has_ball():
raise ValueError(
f"Event of type {self.event_type} does not involve a Ball."
)

agent = self._find_agent(AgentType.BALL, ball_id)
obj = agent._get_state(initial)
return cast(Ball, obj)

def get_pocket(self, pocket_id: str, initial: bool = True) -> Pocket:
"""Return the Pocket object with the given ID, either final or initial."""
if not self.event_type.has_pocket():
raise ValueError(
f"Event of type {self.event_type} does not involve a Pocket."
)

agent = self._find_agent(AgentType.POCKET, pocket_id)
obj = agent._get_state(initial)
return cast(Pocket, obj)

def get_cushion(
self, cushion_id: str
) -> Union[LinearCushionSegment, CircularCushionSegment]:
"""Return the cushion segment with the given ID."""
if not self.event_type.has_cushion():
raise ValueError(
f"Event of type {self.event_type} does not involve a cushion."
)

try:
agent = self._find_agent(AgentType.LINEAR_CUSHION_SEGMENT, cushion_id)
return cast(LinearCushionSegment, agent.initial)
except ValueError:
pass

try:
agent = self._find_agent(AgentType.CIRCULAR_CUSHION_SEGMENT, cushion_id)
return cast(CircularCushionSegment, agent.initial)
except ValueError:
pass

raise ValueError(
f"No agent of linear/circular cushion with ID '{cushion_id}' found in this event."
)

def get_stick(self, stick_id: str) -> Pocket:
"""Return the cue stick with the given ID."""
if not self.event_type.has_pocket():
raise ValueError(
f"Event of type {self.event_type} does not involve a Pocket."
)

agent = self._find_agent(AgentType.POCKET, stick_id)
return cast(Pocket, agent.initial)
7 changes: 1 addition & 6 deletions pooltool/evolution/continuize.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,7 @@ def continuize(system: System, dt: float = 0.01, inplace: bool = False) -> Syste

# We need to get the ball's outgoing state from the event. We'll
# evolve the system from this state.
for agent in events[count].agents:
if agent.matches(ball):
state = agent.final.state.copy() # type: ignore
break
else:
raise ValueError("No agents in event match ball")
state = events[count].get_ball(ball.id, initial=False).state.copy()

rvw, s = state.rvw, state.s

Expand Down
Binary file added tests/events/example_system.msgpack
Binary file not shown.
128 changes: 128 additions & 0 deletions tests/events/test_datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
from pathlib import Path
from typing import List

import pytest

from pooltool.events.datatypes import Event, EventType
from pooltool.objects.ball.datatypes import Ball
from pooltool.objects.table.components import (
CircularCushionSegment,
LinearCushionSegment,
Pocket,
)
from pooltool.system.datatypes import System


@pytest.fixture
def example_events() -> List[Event]:
"""
Returns the list of Event objects from simulating the example system.
"""
return System.load(Path(__file__).parent / "example_system.msgpack").events


def test_get_ball_success(example_events: List[Event]):
"""
Find an event that involves a ball (e.g. BALL_BALL or STICK_BALL)
and verify we can retrieve the ball by ID.
"""
# We'll look for a BALL_BALL event that (based on your snippet) should have agents: ('cue', '1')
event = next(e for e in example_events if e.event_type == EventType.BALL_BALL)

# Try retrieving the ball named "cue"
cue_ball = event.get_ball("cue", initial=False) # final state by default
assert isinstance(cue_ball, Ball)
assert cue_ball.id == "cue"

# Also retrieve the "1" ball by initial state
ball_1_initial = event.get_ball("1", initial=True)
assert isinstance(ball_1_initial, Ball)
assert ball_1_initial.id == "1"


def test_get_ball_no_ball_in_event(example_events: List[Event]):
"""
Attempt to retrieve a ball from an event type that doesn't involve a ball, expecting ValueError.
"""
null_event = example_events[0]
assert null_event.event_type == EventType.NONE

with pytest.raises(ValueError, match="does not involve a Ball"):
null_event.get_ball("dummy")


def test_get_ball_wrong_id(example_events: List[Event]):
"""
Attempt to retrieve a ball using an ID not present in a ball-involving event.
"""
event = next(e for e in example_events if e.event_type == EventType.STICK_BALL)

with pytest.raises(ValueError, match="No agent of type ball"):
event.get_ball("1")


def test_get_cushion_success(example_events: List[Event]):
"""
Find a BALL_LINEAR_CUSHION or BALL_CIRCULAR_CUSHION event and verify we can retrieve the cushion.
"""
# Agents: ('cue','6')
linear_event = next(
e for e in example_events if e.event_type == EventType.BALL_LINEAR_CUSHION
)

cushion_obj = linear_event.get_cushion("6")
assert isinstance(cushion_obj, LinearCushionSegment)
assert cushion_obj.id == "6"

# Agents ('cue', '8t')
circular_event = next(
e for e in example_events if e.event_type == EventType.BALL_CIRCULAR_CUSHION
)
cushion_obj_circ = circular_event.get_cushion("8t")
assert isinstance(cushion_obj_circ, CircularCushionSegment)
assert cushion_obj_circ.id == "8t"


def test_get_cushion_not_in_event(example_events: List[Event]):
"""
Attempt to retrieve a cushion from an event that doesn't involve one.
"""
event = next(e for e in example_events if e.event_type == EventType.BALL_BALL)
with pytest.raises(ValueError, match="does not involve a cushion"):
event.get_cushion("8t")


def test_get_pocket_success(example_events: List[Event]):
"""
Find a BALL_POCKET event (agents: ('1','rt') in your snippet) and retrieve the pocket.
"""
pocket_event = next(
e for e in example_events if e.event_type == EventType.BALL_POCKET
)
pocket_obj = pocket_event.get_pocket("rt", initial=False)
assert isinstance(pocket_obj, Pocket)
assert pocket_obj.id == "rt"


def test_get_pocket_not_in_event(example_events: List[Event]):
"""
Attempt to retrieve a pocket from a non-pocket event, expecting ValueError.
"""
event = next(e for e in example_events if e.event_type == EventType.BALL_BALL)
with pytest.raises(
ValueError, match="Event of type ball_ball does not involve a Pocket"
):
event.get_pocket("rt")


def test_get_pocket_missing_id(example_events: List[Event]):
"""
Attempt to retrieve a pocket with an ID that doesn't match the event's pocket.
"""
pocket_event = next(
e for e in example_events if e.event_type == EventType.BALL_POCKET
)
with pytest.raises(
ValueError, match="No agent of type pocket with ID 'non_existent_pocket_id'"
):
pocket_event.get_pocket("non_existent_pocket_id")

0 comments on commit 3d8fd67

Please sign in to comment.