Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: base classes and jit #13

Merged
merged 5 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 37 additions & 6 deletions src/vector/_base.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,65 @@
"""Representation of coordinates in different systems."""

__all__ = ["AbstractVector", "AbstractVectorDifferential"]
__all__ = ["AbstractVectorBase", "AbstractVector", "AbstractVectorDifferential"]

import warnings
from abc import abstractmethod
from functools import partial
from typing import Any, TypeVar

import equinox as eqx
import jax

T = TypeVar("T", bound="AbstractVector")
BT = TypeVar("BT", bound="AbstractVectorBase")
VT = TypeVar("VT", bound="AbstractVector")
DT = TypeVar("DT", bound="AbstractVectorDifferential")


class AbstractVector(eqx.Module): # type: ignore[misc]
class AbstractVectorBase(eqx.Module): # type: ignore[misc]
"""Base class for all vector types."""

# ===============================================================
# Convenience methods

@abstractmethod
def represent_as(self, target: type[BT], /, *args: Any, **kwargs: Any) -> BT:
"""Represent the vector as another type."""
raise NotImplementedError


class AbstractVector(AbstractVectorBase):
"""Abstract representation of coordinates in different systems."""

def represent_as(self, target: type[T], /, **kwargs: Any) -> T:
# ===============================================================
# Convenience methods

@partial(jax.jit, static_argnums=1)
def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
"""Represent the vector as another type."""
if any(args):
warnings.warn("Extra arguments are ignored.", UserWarning, stacklevel=2)

from ._transform import represent_as # pylint: disable=import-outside-toplevel

return represent_as(self, target, **kwargs)


class AbstractVectorDifferential(eqx.Module): # type: ignore[misc]
class AbstractVectorDifferential(AbstractVectorBase):
"""Abstract representation of vector differentials in different systems."""

vector_cls: eqx.AbstractClassVar[type[AbstractVector]]

# ===============================================================
# Convenience methods

@partial(jax.jit, static_argnums=1)
def represent_as(
self, target: type[DT], position: AbstractVector, /, **kwargs: Any
self, target: type[DT], position: AbstractVector, /, *args: Any, **kwargs: Any
) -> DT:
"""Represent the vector as another type."""
if any(args):
warnings.warn("Extra arguments are ignored.", UserWarning, stacklevel=2)

from ._transform import represent_as # pylint: disable=import-outside-toplevel

return represent_as(self, target, position, **kwargs)
40 changes: 34 additions & 6 deletions src/vector/_d1/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,36 @@

from plum import dispatch

from .base import Abstract1DVector
from .builtin import Cartesian1DVector, RadialVector
from vector._base import AbstractVector

from .base import Abstract1DVector, Abstract1DVectorDifferential
from .builtin import (
Cartesian1DVector,
CartesianDifferential1D,
RadialDifferential,
RadialVector,
)

###############################################################################
# 1D


@dispatch(precedence=1)
def represent_as(
current: Cartesian1DVector, target: type[Cartesian1DVector], /, **kwargs: Any
) -> Cartesian1DVector:
"""Self transform of 1D vectors."""
return current


@dispatch(precedence=1)
def represent_as(
current: RadialVector, target: type[RadialVector], /, **kwargs: Any
) -> RadialVector:
"""Self transform of 1D vectors."""
return current


@dispatch
def represent_as(
current: Abstract1DVector, target: type[Abstract1DVector], /, **kwargs: Any
Expand All @@ -26,12 +49,17 @@ def represent_as(


@dispatch.multi(
(Cartesian1DVector, type[Cartesian1DVector]), (RadialVector, type[RadialVector])
(CartesianDifferential1D, type[CartesianDifferential1D], AbstractVector),
(RadialDifferential, type[RadialDifferential], AbstractVector),
)
def represent_as(
current: Abstract1DVector, target: type[Abstract1DVector], /, **kwargs: Any
) -> Abstract1DVector:
"""Self transform of 1D vectors."""
current: Abstract1DVectorDifferential,
target: type[Abstract1DVectorDifferential],
position: AbstractVector,
/,
**kwargs: Any,
) -> Abstract1DVectorDifferential:
"""Self transform of 1D Differentials."""
return current


Expand Down
26 changes: 24 additions & 2 deletions src/vector/_d2/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,15 @@
import array_api_jax_compat as xp
from plum import dispatch

from .base import Abstract2DVector # pylint: disable=cyclic-import
from .builtin import Cartesian2DVector, PolarVector
from vector._base import AbstractVector

from .base import Abstract2DVector, Abstract2DVectorDifferential
from .builtin import (
Cartesian2DVector,
CartesianDifferential2D,
PolarDifferential,
PolarVector,
)


@dispatch
Expand All @@ -35,6 +42,21 @@ def represent_as(
return current


@dispatch.multi(
(CartesianDifferential2D, type[CartesianDifferential2D], AbstractVector),
(PolarDifferential, type[PolarDifferential], AbstractVector),
)
def represent_as(
current: Abstract2DVectorDifferential,
target: type[Abstract2DVectorDifferential],
position: AbstractVector,
/,
**kwargs: Any,
) -> Abstract2DVectorDifferential:
"""Self transform of 2D Differentials."""
return current


# @dispatch.multi(
# (Cartesian2DVector, type[LnPolarVector]),
# (Cartesian2DVector, type[Log10PolarVector]),
Expand Down
29 changes: 27 additions & 2 deletions src/vector/_d3/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,17 @@
import array_api_jax_compat as xp
from plum import dispatch

from .base import Abstract3DVector
from .builtin import Cartesian3DVector, CylindricalVector, SphericalVector
from vector._base import AbstractVector

from .base import Abstract3DVector, Abstract3DVectorDifferential
from .builtin import (
Cartesian3DVector,
CartesianDifferential3D,
CylindricalDifferential,
CylindricalVector,
SphericalDifferential,
SphericalVector,
)

###############################################################################
# 3D
Expand All @@ -34,6 +43,22 @@ def represent_as(
return current


@dispatch.multi(
(CartesianDifferential3D, type[CartesianDifferential3D], AbstractVector),
(SphericalDifferential, type[SphericalDifferential], AbstractVector),
(CylindricalDifferential, type[CylindricalDifferential], AbstractVector),
)
def represent_as(
current: Abstract3DVectorDifferential,
target: type[Abstract3DVectorDifferential],
position: AbstractVector,
/,
**kwargs: Any,
) -> Abstract3DVectorDifferential:
"""Self transform of 3D Differentials."""
return current


# =============================================================================
# Cartesian3DVector

Expand Down
56 changes: 11 additions & 45 deletions src/vector/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,48 +14,31 @@

from ._base import AbstractVector, AbstractVectorDifferential
from ._d1.base import Abstract1DVectorDifferential
from ._d1.builtin import (
Cartesian1DVector,
CartesianDifferential1D,
RadialDifferential,
RadialVector,
)
from ._d1.builtin import Cartesian1DVector, RadialVector
from ._d2.base import Abstract2DVector, Abstract2DVectorDifferential
from ._d2.builtin import (
Cartesian2DVector,
CartesianDifferential2D,
PolarDifferential,
PolarVector,
)
from ._d2.builtin import Cartesian2DVector, PolarVector
from ._d3.base import Abstract3DVector, Abstract3DVectorDifferential
from ._d3.builtin import (
Cartesian3DVector,
CartesianDifferential3D,
CylindricalDifferential,
CylindricalVector,
SphericalDifferential,
SphericalVector,
)
from ._d3.builtin import Cartesian3DVector, CylindricalVector, SphericalVector
from ._exceptions import IrreversibleDimensionChange
from ._utils import fields_and_values


# TODO: implement for cross-representations
@dispatch.multi(
@dispatch.multi( # type: ignore[misc]
# N-D -> N-D
(
Abstract1DVectorDifferential,
type[Abstract1DVectorDifferential],
type[Abstract1DVectorDifferential], # type: ignore[misc]
AbstractVector,
),
(
Abstract2DVectorDifferential,
type[Abstract2DVectorDifferential],
type[Abstract2DVectorDifferential], # type: ignore[misc]
AbstractVector,
),
(
Abstract3DVectorDifferential,
type[Abstract3DVectorDifferential],
type[Abstract3DVectorDifferential], # type: ignore[misc]
AbstractVector,
),
)
Expand All @@ -80,9 +63,7 @@ def represent_as(
# the correct numerator unit (of the Jacobian row). The value is a Vector of the
# original type, with fields that are the columns of that row, but with only the
# denomicator's units.
jac_nested_vecs = jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None))(
current_position, target.vector_cls
)
jac_nested_vecs = jac_rep_as(current_position, target.vector_cls)

# This changes the Jacobian to be a dictionary of each row, with the value
# being that row's column as a dictionary, now with the correct units for
Expand All @@ -109,25 +90,10 @@ def represent_as(
)


# Self transform
@dispatch.multi(
(CartesianDifferential1D, type[CartesianDifferential1D], AbstractVector),
(RadialDifferential, type[RadialDifferential], AbstractVector),
(CartesianDifferential2D, type[CartesianDifferential2D], AbstractVector),
(PolarDifferential, type[PolarDifferential], AbstractVector),
(CartesianDifferential3D, type[CartesianDifferential3D], AbstractVector),
(SphericalDifferential, type[SphericalDifferential], AbstractVector),
(CylindricalDifferential, type[CylindricalDifferential], AbstractVector),
# TODO: situate this better to show how represent_as is used
jac_rep_as = jax.jit(
jax.vmap(jax.jacfwd(represent_as), in_axes=(0, None)), static_argnums=(1,)
)
def represent_as(
current: AbstractVectorDifferential,
target: type[AbstractVectorDifferential],
position: AbstractVector,
/,
**kwargs: Any,
) -> AbstractVectorDifferential:
"""Self transform."""
return current


###############################################################################
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def _convert_jax_array(x: Shaped[Array, "*shape"], /) -> Float[Quantity, "*shape

def fields_and_values(obj: "DataclassInstance") -> Iterator[tuple[Field[Any], Any]]:
"""Return the fields and values of a dataclass instance."""
return ((f, getattr(obj, f.name)) for f in fields(obj))
yield from ((f, getattr(obj, f.name)) for f in fields(obj))
8 changes: 6 additions & 2 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,11 @@ def context_dimension_reduction(
return context


class AbstractVectorTest:
class AbstractVectorBaseTest:
"""Test :class:`vector.AbstractVectorBase`."""


class AbstractVectorTest(AbstractVectorBaseTest):
"""Test :class:`vector.AbstractVector`."""

@pytest.fixture(scope="class")
Expand All @@ -102,7 +106,7 @@ def test_represent_as(self, vector, target):
assert isinstance(newvec, target)


class AbstractVectorDifferentialTest:
class AbstractVectorDifferentialTest(AbstractVectorBaseTest):
"""Test :class:`vector.AbstractVectorDifferential`."""

@pytest.fixture(scope="class")
Expand Down
Loading