Skip to content

Commit

Permalink
🎨 style(ops): simplify ops (#356)
Browse files Browse the repository at this point in the history
* ♻️ refactor(ops): simplify dispatch
* 🎨 style(ops): simplify base operator
* ♻️ refactor(ops): update Pipe converter

Signed-off-by: Nathaniel Starkman <[email protected]>
  • Loading branch information
nstarman authored Jan 23, 2025
1 parent d2e2ecf commit 3525e2c
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 121 deletions.
2 changes: 1 addition & 1 deletion src/coordinax/_src/frames/coordinate.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ def vconvert(target: type[AbstractPos], w: Coordinate, /) -> Coordinate:
# Transform operations


@AbstractOperator.__call__.dispatch # type: ignore[attr-defined, misc]
@AbstractOperator.__call__.dispatch # type: ignore[misc]
def call(self: AbstractOperator, x: Coordinate, /) -> Coordinate:
"""Dispatch to the operator's `__call__` method.
Expand Down
4 changes: 3 additions & 1 deletion src/coordinax/_src/operators/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
"GalileanTranslation",
# Misc
"VelocityBoost",
# Utils
"convert_to_pipe_operators",
]

from . import galilean
Expand All @@ -31,7 +33,7 @@
from .galilean.spatial_translation import GalileanSpatialTranslation
from .galilean.translation import GalileanTranslation
from .identity import Identity
from .pipe import Pipe
from .pipe import Pipe, convert_to_pipe_operators

# isort: split
from . import compat, register_simplify
Expand Down
154 changes: 73 additions & 81 deletions src/coordinax/_src/operators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from coordinax._src.vectors.base_pos import AbstractPos

if TYPE_CHECKING:
from coordinax.ops import Pipe
import coordinax.ops


class AbstractOperator(eqx.Module):
Expand All @@ -41,93 +41,20 @@ class AbstractOperator(eqx.Module):
# Constructors

@classmethod
@dispatch(precedence=-1)
@dispatch.abstract
def from_(
cls: "type[AbstractOperator]", *args: object, **kwargs: object
) -> "AbstractOperator":
"""Construct from a set of arguments.
This is a low-priority dispatch that will be called if no other
dispatch is found. It just tries to pass the arguments to the
constructor.
"""
return cls(*args, **kwargs)

@classmethod
@dispatch
def from_(
cls: "type[AbstractOperator]", obj: Mapping[str, Any], /
) -> "AbstractOperator":
"""Construct from a mapping.
Examples
--------
>>> import coordinax as cx
>>> pipe = cx.ops.Identity() | cx.ops.Identity()
>>> cx.ops.Pipe.from_({"operators": pipe})
Pipe((Identity(), Identity()))
"""
return cls(**obj)

@classmethod
@dispatch
def from_(
cls: "type[AbstractOperator]",
x: ArrayLike | list[float | int],
unit: str, # TODO: support unit object
/,
) -> "AbstractOperator":
"""Construct from a Quantity's value and unit.
Examples
--------
>>> import coordinax as cx
>>> op = cx.ops.GalileanSpatialTranslation.from_([1, 1, 1], "km")
>>> print(op.translation)
<CartesianPos3D (x[km], y[km], z[km])
[1 1 1]>
>>> op = cx.ops.GalileanTranslation.from_([3e5, 1, 1, 1], "km")
>>> print(op.translation)
<FourVector (t[s], q=(x[km], y[km], z[km]))
[1.001 1. 1. 1. ]>
>>> op = cx.ops.GalileanBoost.from_([1, 1, 1], "km/s")
>>> print(op.velocity)
<CartesianVel3D (d_x[km / s], d_y[km / s], d_z[km / s])
[1 1 1]>
"""
return cls(u.Quantity(x, unit))
"""Construct from a set of arguments."""
raise NotImplementedError # pragma: no cover

# ===========================================
# Operator API

@dispatch.abstract
def __call__(
self: "AbstractOperator",
x: AbstractPos, # noqa: ARG002
/,
**kwargs: Any, # noqa: ARG002
) -> AbstractPos:
def __call__(self: "AbstractOperator", *args: Any, **kwargs: Any) -> AbstractPos:
"""Apply the operator to the coordinates `x`."""
msg = "implement this method in the subclass"
raise TypeError(msg)

@dispatch.abstract
def __call__(
self: "AbstractOperator",
x: AbstractPos, # noqa: ARG002
t: u.Quantity["time"], # noqa: ARG002
/,
**kwargs: Any, # noqa: ARG002
) -> AbstractPos:
"""Apply the operator to the coordinates `x` at a time `t`."""
msg = "implement this method in the subclass"
raise TypeError(msg)
raise NotImplementedError # pragma: no cover

# -------------------------------------------

Expand Down Expand Up @@ -216,7 +143,7 @@ def __str__(self) -> str:
# ===========================================
# Operator Composition

def __or__(self, other: "AbstractOperator") -> "Pipe":
def __or__(self, other: "AbstractOperator") -> "coordinax.ops.Pipe":
"""Compose with another operator.
Examples
Expand All @@ -241,7 +168,72 @@ def __or__(self, other: "AbstractOperator") -> "Pipe":
return Pipe((self, other))


@AbstractOperator.from_.dispatch # type: ignore[attr-defined, misc]
# ============================================================
# Constructors


@AbstractOperator.from_.dispatch(precedence=-1)
def from_(
cls: type[AbstractOperator], *args: object, **kwargs: object
) -> AbstractOperator:
"""Construct from a set of arguments.
This is a low-priority dispatch that will be called if no other
dispatch is found. It just tries to pass the arguments to the
constructor.
"""
return cls(*args, **kwargs)


@AbstractOperator.from_.dispatch
def from_(cls: type[AbstractOperator], obj: Mapping[str, Any], /) -> AbstractOperator:
"""Construct from a mapping.
Examples
--------
>>> import coordinax as cx
>>> pipe = cx.ops.Identity() | cx.ops.Identity()
>>> cx.ops.Pipe.from_({"operators": pipe})
Pipe((Identity(), Identity()))
"""
return cls(**obj)


@AbstractOperator.from_.dispatch
def from_(
cls: type[AbstractOperator],
x: ArrayLike | list[float | int],
unit: str, # TODO: support unit object
/,
) -> AbstractOperator:
"""Construct from a Quantity's value and unit.
Examples
--------
>>> import coordinax as cx
>>> op = cx.ops.GalileanSpatialTranslation.from_([1, 1, 1], "km")
>>> print(op.translation)
<CartesianPos3D (x[km], y[km], z[km])
[1 1 1]>
>>> op = cx.ops.GalileanTranslation.from_([3e5, 1, 1, 1], "km")
>>> print(op.translation)
<FourVector (t[s], q=(x[km], y[km], z[km]))
[1.001 1. 1. 1. ]>
>>> op = cx.ops.GalileanBoost.from_([1, 1, 1], "km/s")
>>> print(op.velocity)
<CartesianVel3D (d_x[km / s], d_y[km / s], d_z[km / s])
[1 1 1]>
"""
return cls(u.Quantity(x, unit))


@AbstractOperator.from_.dispatch
def from_(cls: type[AbstractOperator], obj: AbstractOperator, /) -> AbstractOperator:
"""Construct an operator from another operator.
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_src/operators/boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def inverse(self) -> "VelocityBoost":

# -----------------------------------------------------

@AbstractOperator.__call__.dispatch # type: ignore[attr-defined, misc]
@AbstractOperator.__call__.dispatch # type: ignore[misc]
def __call__(self: "VelocityBoost", p: AbstractVel, /) -> AbstractVel:
"""Apply the boost to the coordinates.
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_src/operators/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def inverse(self: "AbstractCompositeOperator") -> "coordinax.ops.Pipe":

return Pipe(tuple(op.inverse for op in reversed(self.operators)))

@AbstractOperator.__call__.dispatch(precedence=1) # type: ignore[attr-defined, misc]
@AbstractOperator.__call__.dispatch(precedence=1) # type: ignore[misc]
def __call__(
self: "AbstractCompositeOperator", *args: object, **kwargs: Any
) -> tuple[object, ...]:
Expand Down Expand Up @@ -114,7 +114,7 @@ def __iter__(self: HasOperatorsAttr) -> Iterator[AbstractOperator]:
# Call dispatches


@AbstractOperator.__call__.dispatch(precedence=1) # type: ignore[attr-defined, misc]
@AbstractOperator.__call__.dispatch(precedence=1) # type: ignore[misc]
def call(
self: AbstractCompositeOperator, x: AbstractVector, /, **kwargs: Any
) -> AbstractVector:
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_src/operators/galilean/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def from_euler(
return cls(rotation=R)

@classmethod
@AbstractOperator.from_.dispatch # type: ignore[attr-defined, misc]
@AbstractOperator.from_.dispatch # type: ignore[misc]
def from_(cls: "type[GalileanRotation]", obj: Rotation, /) -> "GalileanRotation":
"""Initialize from a `jax.scipy.spatial.transform.Rotation`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def inverse(self) -> "GalileanSpatialTranslation":

# -------------------------------------------

@AbstractOperator.__call__.dispatch # type: ignore[attr-defined, misc]
@AbstractOperator.__call__.dispatch # type: ignore[misc]
def __call__(
self: "GalileanSpatialTranslation", q: AbstractPos, /, **__: Any
) -> AbstractPos:
Expand Down
78 changes: 65 additions & 13 deletions src/coordinax/_src/operators/pipe.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,21 @@
"""Sequence of Operators."""

__all__ = ["Pipe"]
__all__ = ["Pipe", "convert_to_pipe_operators"]

import textwrap
from dataclasses import replace
from typing import Any, final

import equinox as eqx
from plum import dispatch

from .base import AbstractOperator
from .composite import AbstractCompositeOperator


def _converter_seq(inp: Any) -> tuple[AbstractOperator, ...]:
if isinstance(inp, tuple):
return inp
if isinstance(inp, list):
return tuple(inp)
if isinstance(inp, Pipe):
return inp.operators
if isinstance(inp, AbstractOperator):
return (inp,)

raise TypeError
@dispatch.abstract
def convert_to_pipe_operators(inp: Any, /) -> tuple[AbstractOperator, ...]:
raise NotImplementedError # pragma: no cover


@final
Expand Down Expand Up @@ -97,7 +90,9 @@ class Pipe(AbstractCompositeOperator):
"""

operators: tuple[AbstractOperator, ...] = eqx.field(converter=_converter_seq)
operators: tuple[AbstractOperator, ...] = eqx.field(
converter=convert_to_pipe_operators
)

# ---------------------------------------------------------------

Expand Down Expand Up @@ -133,3 +128,60 @@ def __repr__(self) -> str:
if "\n" in ops:
ops = "(\n" + textwrap.indent(ops[1:-1], " ") + "\n)"
return f"{self.__class__.__name__}({ops})"


# ==============================================================
# Constructor


@dispatch
def convert_to_pipe_operators(
inp: tuple[AbstractOperator, ...] | list[AbstractOperator],
) -> tuple[AbstractOperator, ...]:
"""Convert to a tuple of operators.
Examples
--------
>>> import coordinax as cx
>>> op1 = cx.ops.GalileanRotation([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
>>> op2 = cx.ops.Identity()
>>> convert_to_pipe_operators((op1, op2))
(GalileanRotation(rotation=i32[3,3]), Identity())
"""
return tuple(inp)


@dispatch
def convert_to_pipe_operators(inp: AbstractOperator) -> tuple[AbstractOperator, ...]:
"""Convert to a tuple of operators.
Examples
--------
>>> import coordinax as cx
>>> op1 = cx.ops.GalileanRotation([[1, 0, 0], [0, 1, 0], [0, 0, 1]])
>>> convert_to_pipe_operators(op1)
(GalileanRotation(rotation=i32[3,3]),)
"""
return (inp,)


@dispatch
def convert_to_pipe_operators(inp: Pipe) -> tuple[AbstractOperator, ...]:
"""Convert to a tuple of operators.
Examples
--------
>>> import coordinax as cx
>>> op1 = cx.ops.Identity()
>>> op2 = cx.ops.Identity()
>>> pipe = cx.ops.Pipe((op1, op2))
>>> convert_to_pipe_operators(pipe)
(Identity(), Identity())
"""
return inp.operators
Loading

0 comments on commit 3525e2c

Please sign in to comment.