Skip to content

Commit

Permalink
feat: dimension hints (#23)
Browse files Browse the repository at this point in the history
* build: opt deps all
* feat: type-enforce dimensions

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Feb 24, 2024
1 parent 0fa7b43 commit 27e7f9f
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 82 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@
"D203", # 1 blank line required before class docstring
"D213", # Multi-line docstring summary should start at the second line
"ERA001", # Commented out code
"F722", # Syntax error in forward annotation <- jaxtyping
"F722", # Syntax error in forward annotation <- jaxtyping
"F811", # redefinition of unused '...' <- plum-dispatch
"F821", # undefined name '...' <- jaxtyping
"FIX002", # Line contains TODO
Expand Down
6 changes: 6 additions & 0 deletions src/vector/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,12 @@
DT = TypeVar("DT", bound="AbstractVectorDifferential")


_0m = Quantity(0, "meter")
_0d = Quantity(0, "rad")
_pid = Quantity(xp.pi, "rad")
_2pid = Quantity(2 * xp.pi, "rad")


class AbstractVectorBase(eqx.Module): # type: ignore[misc]
"""Base class for all vector types.
Expand Down
42 changes: 42 additions & 0 deletions src/vector/_checks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Representation of coordinates in different systems."""

__all__: list[str] = []


import array_api_jax_compat as xp
import equinox as eqx
from jax_quantity import Quantity

from vector._typing import BatchableAngle, BatchableLength

_0m = Quantity(0, "meter")
_0d = Quantity(0, "rad")
_pid = Quantity(xp.pi, "rad")
_2pid = Quantity(2 * xp.pi, "rad")


def check_r_non_negative(r: BatchableLength) -> BatchableLength:
"""Check that the radial distance is non-negative."""
return eqx.error_if(
r,
xp.any(r < _0m),
"The radial distance must be non-negative.",
)


def check_phi_range(phi: BatchableAngle) -> BatchableAngle:
"""Check that the polar angle is in the range [0, 2pi)."""
return eqx.error_if(
phi,
xp.any((phi < _0d) | (phi >= _2pid)),
"The azimuthal (polar) angle must be in the range [0, 2pi).",
)


def check_theta_range(theta: BatchableAngle) -> BatchableAngle:
"""Check that the inclination angle is in the range [0, pi]."""
return eqx.error_if(
theta,
xp.any((theta < _0d) | (theta > _pid)),
"The inclination angle must be in the range [0, pi].",
)
23 changes: 14 additions & 9 deletions src/vector/_d1/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import equinox as eqx

from vector._typing import BatchableFloatScalarQ
from vector._checks import check_r_non_negative
from vector._typing import BatchableLength, BatchableSpeed
from vector._utils import converter_quantity_array

from .base import Abstract1DVector, Abstract1DVectorDifferential
Expand All @@ -26,16 +27,20 @@
class Cartesian1DVector(Abstract1DVector):
"""Cartesian vector representation."""

x: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
"""x coordinate."""
x: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""X coordinate :math:`x \in (-\infty,+\infty)`."""


@final
class RadialVector(Abstract1DVector):
"""Radial vector representation."""

r: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
"""Radial coordinate."""
r: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Radial distance :math:`r \in [0,+\infty)`."""

def __check_init__(self) -> None:
"""Check the initialization."""
check_r_non_negative(self.r)


##############################################################################
Expand All @@ -46,8 +51,8 @@ class RadialVector(Abstract1DVector):
class CartesianDifferential1D(Abstract1DVectorDifferential):
"""Cartesian differential representation."""

d_x: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
"""Differential d_x/d_<>."""
d_x: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""X differential :math:`dx/dt \in (-\infty,+\infty`)`."""

vector_cls: ClassVar[type[Cartesian1DVector]] = Cartesian1DVector # type: ignore[misc]

Expand All @@ -56,7 +61,7 @@ class CartesianDifferential1D(Abstract1DVectorDifferential):
class RadialDifferential(Abstract1DVectorDifferential):
"""Radial differential representation."""

d_r: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
"""Differential d_r/d_<>."""
d_r: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Radial speed :math:`dr/dt \in (-\infty,+\infty)`."""

vector_cls: ClassVar[type[RadialVector]] = RadialVector # type: ignore[misc]
41 changes: 32 additions & 9 deletions src/vector/_d2/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,13 @@

import equinox as eqx

from vector._typing import BatchableFloatScalarQ
from vector._checks import check_phi_range, check_r_non_negative
from vector._typing import (
BatchableAngle,
BatchableAngularSpeed,
BatchableLength,
BatchableSpeed,
)
from vector._utils import converter_quantity_array

from .base import Abstract2DVector, Abstract2DVectorDifferential
Expand All @@ -28,8 +34,11 @@
class Cartesian2DVector(Abstract2DVector):
"""Cartesian vector representation."""

x: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
y: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
x: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""X coordinate :math:`x \in (-\infty,+\infty)`."""

y: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Y coordinate :math:`y \in (-\infty,+\infty)`."""


@final
Expand All @@ -39,8 +48,16 @@ class PolarVector(Abstract2DVector):
We use the symbol `phi` instead of `theta` to adhere to the ISO standard.
"""

r: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
phi: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
r: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Radial distance :math:`r \in [0,+\infty)`."""

phi: BatchableAngle = eqx.field(converter=converter_quantity_array)
r"""Polar angle :math:`\phi \in [0,2\pi)`."""

def __check_init__(self) -> None:
"""Check the initialization."""
check_r_non_negative(self.r)
check_phi_range(self.phi)


# class LnPolarVector(Abstract2DVector):
Expand All @@ -64,8 +81,11 @@ class PolarVector(Abstract2DVector):
class CartesianDifferential2D(Abstract2DVectorDifferential):
"""Cartesian differential representation."""

d_x: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_y: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_x: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""X coordinate differential :math:`\dot{x} \in (-\infty,+\infty)`."""

d_y: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Y coordinate differential :math:`\dot{y} \in (-\infty,+\infty)`."""

vector_cls: ClassVar[type[Cartesian2DVector]] = Cartesian2DVector # type: ignore[misc]

Expand All @@ -74,7 +94,10 @@ class CartesianDifferential2D(Abstract2DVectorDifferential):
class PolarDifferential(Abstract2DVectorDifferential):
"""Polar differential representation."""

d_r: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_phi: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_r: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Radial speed :math:`dr/dt \in [-\infty,+\infty]`."""

d_phi: BatchableAngularSpeed = eqx.field(converter=converter_quantity_array)
r"""Polar angular speed :math:`d\phi/dt \in [-\infty,+\infty]`."""

vector_cls: ClassVar[type[PolarVector]] = PolarVector # type: ignore[misc]
111 changes: 48 additions & 63 deletions src/vector/_d3/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,19 @@

from typing import ClassVar, final

import array_api_jax_compat as xp
import equinox as eqx
from jax_quantity import Quantity
from jaxtyping import Shaped

from vector._typing import BatchableFloatScalarQ
from vector._checks import check_phi_range, check_r_non_negative, check_theta_range
from vector._typing import (
BatchableAngle,
BatchableAngularSpeed,
BatchableLength,
BatchableSpeed,
)
from vector._utils import converter_quantity_array

from .base import Abstract3DVector, Abstract3DVectorDifferential

_0m = Quantity(0, "meter")
_0d = Quantity(0, "rad")
_pid = Quantity(xp.pi, "rad")
_2pid = Quantity(2 * xp.pi, "rad")

##############################################################################
# Position

Expand All @@ -36,81 +34,53 @@
class Cartesian3DVector(Abstract3DVector):
"""Cartesian vector representation."""

x: Shaped[Quantity["length"], "*#batch"] = eqx.field(
converter=converter_quantity_array
)
r"""X-coordinate :math:`x \in [-\infty, \infty]."""
x: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""X coordinate :math:`x \in (-\infty,+\infty)`."""

y: Shaped[Quantity["length"], "*#batch"] = eqx.field(
converter=converter_quantity_array
)
r"""Y-coordinate :math:`y \in [-\infty, \infty]."""
y: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Y coordinate :math:`y \in (-\infty,+\infty)`."""

z: Shaped[Quantity["length"], "*#batch"] = eqx.field(
converter=converter_quantity_array
)
r"""Z-coordinate :math:`z \in [-\infty, \infty]."""
z: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Z coordinate :math:`z \in (-\infty,+\infty)`."""


@final
class SphericalVector(Abstract3DVector):
"""Spherical vector representation."""

r: Shaped[Quantity["length"], "*#batch"] = eqx.field(
converter=converter_quantity_array
)
r: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Radial distance :math:`r \in [0,+\infty)`."""

theta: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
theta: BatchableAngle = eqx.field(converter=converter_quantity_array)
r"""Inclination angle :math:`\phi \in [0,180]`."""

phi: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
phi: BatchableAngle = eqx.field(converter=converter_quantity_array)
r"""Azimuthal angle :math:`\phi \in [0,360)`."""

def __check_init__(self) -> None:
"""Check the validity of the initialisation."""
_ = eqx.error_if(
self.r,
xp.any(self.r < _0m),
"Radial distance 'r' must be in the range [0, +inf).",
)
_ = eqx.error_if(
self.theta,
xp.any((self.theta < _0d) | (self.theta > _pid)),
"Inclination 'theta' must be in the range [0, pi].",
)
_ = eqx.error_if(
self.phi,
xp.any((self.phi < _0d) | (self.phi >= _2pid)),
"Azimuthal angle 'phi' must be in the range [0, 2 * pi).",
)
check_r_non_negative(self.r)
check_theta_range(self.theta)
check_phi_range(self.phi)


@final
class CylindricalVector(Abstract3DVector):
"""Cylindrical vector representation."""

rho: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
rho: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Cylindrical radial distance :math:`\rho \in [0,+\infty)`."""

phi: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
phi: BatchableAngle = eqx.field(converter=converter_quantity_array)
r"""Azimuthal angle :math:`\phi \in [0,360)`."""

z: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
z: BatchableLength = eqx.field(converter=converter_quantity_array)
r"""Height :math:`z \in (-\infty,+\infty)`."""

def __check_init__(self) -> None:
"""Check the validity of the initialisation."""
_ = eqx.error_if(
self.rho,
xp.any(self.rho < _0m),
"Cylindrical radial distance 'rho' must be in the range [0, +inf).",
)
_ = eqx.error_if(
self.phi,
xp.any((self.phi < _0d) | (self.phi >= _2pid)),
"Azimuthal angle 'phi' must be in the range [0, 2 * pi).",
)
check_r_non_negative(self.rho)
check_phi_range(self.phi)


##############################################################################
Expand All @@ -121,9 +91,14 @@ def __check_init__(self) -> None:
class CartesianDifferential3D(Abstract3DVectorDifferential):
"""Cartesian differential representation."""

d_x: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_y: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_z: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_x: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""X speed :math:`dx/dt \in [-\infty, \infty]."""

d_y: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Y speed :math:`dy/dt \in [-\infty, \infty]."""

d_z: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Z speed :math:`dz/dt \in [-\infty, \infty]."""

vector_cls: ClassVar[type[Cartesian3DVector]] = Cartesian3DVector # type: ignore[misc]

Expand All @@ -132,9 +107,14 @@ class CartesianDifferential3D(Abstract3DVectorDifferential):
class SphericalDifferential(Abstract3DVectorDifferential):
"""Spherical differential representation."""

d_r: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_theta: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_phi: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_r: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Radial speed :math:`dr/dt \in [-\infty, \infty]."""

d_theta: BatchableAngularSpeed = eqx.field(converter=converter_quantity_array)
r"""Inclination speed :math:`d\theta/dt \in [-\infty, \infty]."""

d_phi: BatchableAngularSpeed = eqx.field(converter=converter_quantity_array)
r"""Azimuthal speed :math:`d\phi/dt \in [-\infty, \infty]."""

vector_cls: ClassVar[type[SphericalVector]] = SphericalVector # type: ignore[misc]

Expand All @@ -143,8 +123,13 @@ class SphericalDifferential(Abstract3DVectorDifferential):
class CylindricalDifferential(Abstract3DVectorDifferential):
"""Cylindrical differential representation."""

d_rho: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_phi: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_z: BatchableFloatScalarQ = eqx.field(converter=converter_quantity_array)
d_rho: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Cyindrical radial speed :math:`d\rho/dt \in [-\infty, \infty]."""

d_phi: BatchableAngularSpeed = eqx.field(converter=converter_quantity_array)
r"""Azimuthal speed :math:`d\phi/dt \in [-\infty, \infty]."""

d_z: BatchableSpeed = eqx.field(converter=converter_quantity_array)
r"""Vertical speed :math:`dz/dt \in [-\infty, \infty]."""

vector_cls: ClassVar[type[CylindricalVector]] = CylindricalVector # type: ignore[misc]
5 changes: 5 additions & 0 deletions src/vector/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,8 @@
FloatScalarQ = Float[Quantity, ""]
BatchFloatScalarQ = Shaped[FloatScalarQ, "*batch"]
BatchableFloatScalarQ = Shaped[FloatScalarQ, "*#batch"]

BatchableAngle = Shaped[Quantity["angle"], "*#batch"]
BatchableLength = Shaped[Quantity["length"], "*#batch"]
BatchableSpeed = Shaped[Quantity["speed"], "*#batch"]
BatchableAngularSpeed = Shaped[Quantity["angular speed"], "*#batch"]

0 comments on commit 27e7f9f

Please sign in to comment.