Skip to content

Commit

Permalink
feat: use Distance (#79)
Browse files Browse the repository at this point in the history
* feat: use Distance

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Mar 28, 2024
1 parent 5f9956e commit 0165f9a
Show file tree
Hide file tree
Showing 12 changed files with 68 additions and 77 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
"jaxtyping",
"quax>=0.0.3",
"quaxed >= 0.3",
"unxt >= 0.6",
"unxt >= 0.8",
]
description = "Vectors in JAX"
dynamic = ["version"]
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,7 @@ def to_units(
... phi=Quantity(3, "rad"))
>>> sph.to_units({"length": "km", "angle": "deg"})
SphericalVector(
r=Quantity[PhysicalType('length')](value=f32[], unit=Unit("km")),
r=Distance(value=f32[], unit=Unit("km")),
phi=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg")),
theta=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("deg"))
)
Expand Down Expand Up @@ -666,7 +666,7 @@ def to_units(
>>> sph.to_units(ToUnitsOptions.consistent)
SphericalVector(
r=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
r=Distance(value=f32[], unit=Unit("m")),
phi=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("rad")),
theta=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("rad"))
)
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_base_dif.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,9 @@ def __mul__(
>>> dr = RadialDifferential(Quantity(1, "m/s"))
>>> vec = dr * Quantity(2, "s")
>>> vec
RadialVector(r=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")))
RadialVector(r=Distance(value=f32[], unit=Unit("m")))
>>> vec.r
Quantity['length'](Array(2., dtype=float32), unit='m')
Distance(Array(2., dtype=float32), unit='m')
"""
return self.integral_cls.constructor(
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_base_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,12 +172,12 @@ def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
>>> sph = vec.represent_as(SphericalVector)
>>> sph
SphericalVector(
r=Quantity[PhysicalType('length')](value=f32[], unit=Unit("m")),
r=Distance(value=f32[], unit=Unit("m")),
phi=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("rad")),
theta=Quantity[PhysicalType('angle')](value=f32[], unit=Unit("rad"))
)
>>> sph.r
Quantity['length'](Array(3.7416575, dtype=float32), unit='m')
Distance(Array(3.7416575, dtype=float32), unit='m')
"""
if any(args):
Expand Down
18 changes: 9 additions & 9 deletions src/coordinax/_d1/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Distance, Quantity

import coordinax._typing as ct
from .base import Abstract1DVector, Abstract1DVectorDifferential
from coordinax._base_vec import AbstractVector
from coordinax._checks import check_r_non_negative
from coordinax._typing import BatchableLength, BatchableSpeed
from coordinax._utils import classproperty

##############################################################################
Expand All @@ -33,7 +33,7 @@
class Cartesian1DVector(Abstract1DVector):
"""Cartesian vector representation."""

x: BatchableLength = eqx.field(
x: ct.BatchableLength = eqx.field(
converter=partial(Quantity["length"].constructor, dtype=float)
)
r"""X coordinate :math:`x \in (-\infty,+\infty)`."""
Expand Down Expand Up @@ -118,7 +118,7 @@ def __sub__(self, other: Any, /) -> "Cartesian1DVector":
return replace(self, x=self.x - cart.x)

@partial(jax.jit)
def norm(self) -> BatchableLength:
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Examples
Expand All @@ -138,8 +138,8 @@ def norm(self) -> BatchableLength:
class RadialVector(Abstract1DVector):
"""Radial vector representation."""

r: BatchableLength = eqx.field(
converter=partial(Quantity["length"].constructor, dtype=float)
r: ct.BatchableDistance = eqx.field(
converter=partial(Distance.constructor, dtype=float)
)
r"""Radial distance :math:`r \in [0,+\infty)`."""

Expand All @@ -161,7 +161,7 @@ def differential_cls(cls) -> type["RadialDifferential"]:
class CartesianDifferential1D(Abstract1DVectorDifferential):
"""Cartesian differential representation."""

d_x: BatchableSpeed = eqx.field(converter=Quantity["speed"].constructor)
d_x: ct.BatchableSpeed = eqx.field(converter=Quantity["speed"].constructor)
r"""X differential :math:`dx/dt \in (-\infty,+\infty`)`."""

@classproperty
Expand All @@ -170,7 +170,7 @@ def integral_cls(cls) -> type[Cartesian1DVector]:
return Cartesian1DVector

@partial(jax.jit)
def norm(self, _: Abstract1DVector | None = None, /) -> BatchableSpeed:
def norm(self, _: Abstract1DVector | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Examples
Expand All @@ -189,7 +189,7 @@ def norm(self, _: Abstract1DVector | None = None, /) -> BatchableSpeed:
class RadialDifferential(Abstract1DVectorDifferential):
"""Radial differential representation."""

d_r: BatchableSpeed = eqx.field(converter=Quantity["speed"].constructor)
d_r: ct.BatchableSpeed = eqx.field(converter=Quantity["speed"].constructor)
r"""Radial speed :math:`dr/dt \in (-\infty,+\infty)`."""

@classproperty
Expand Down
35 changes: 15 additions & 20 deletions src/coordinax/_d2/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,13 @@
import jax

import quaxed.array_api as xp
from unxt import Quantity
from unxt import Distance, Quantity

import coordinax._typing as ct
from .base import Abstract2DVector, Abstract2DVectorDifferential
from coordinax._base_vec import AbstractVector
from coordinax._checks import check_phi_range, check_r_non_negative
from coordinax._converters import converter_phi_to_range
from coordinax._typing import (
BatchableAngle,
BatchableAngularSpeed,
BatchableLength,
BatchableSpeed,
)
from coordinax._utils import classproperty

# =============================================================================
Expand All @@ -39,12 +34,12 @@
class Cartesian2DVector(Abstract2DVector):
"""Cartesian vector representation."""

x: BatchableLength = eqx.field(
x: ct.BatchableLength = eqx.field(
converter=partial(Quantity["length"].constructor, dtype=float)
)
r"""X coordinate :math:`x \in (-\infty,+\infty)`."""

y: BatchableLength = eqx.field(
y: ct.BatchableLength = eqx.field(
converter=partial(Quantity["length"].constructor, dtype=float)
)
r"""Y coordinate :math:`y \in (-\infty,+\infty)`."""
Expand Down Expand Up @@ -118,7 +113,7 @@ def __sub__(self, other: Any, /) -> "Cartesian2DVector":
return replace(self, x=self.x - cart.x, y=self.y - cart.y)

@partial(jax.jit)
def norm(self) -> BatchableLength:
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Examples
Expand All @@ -140,12 +135,12 @@ class PolarVector(Abstract2DVector):
We use the symbol `phi` instead of `theta` to adhere to the ISO standard.
"""

r: BatchableLength = eqx.field(
converter=partial(Quantity["length"].constructor, dtype=float)
r: ct.BatchableDistance = eqx.field(
converter=partial(Distance.constructor, dtype=float)
)
r"""Radial distance :math:`r \in [0,+\infty)`."""

phi: BatchableAngle = eqx.field(
phi: ct.BatchableAngle = eqx.field(
converter=lambda x: converter_phi_to_range(
Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120
)
Expand All @@ -163,7 +158,7 @@ def differential_cls(cls) -> type["PolarDifferential"]:
return PolarDifferential

@partial(jax.jit)
def norm(self) -> BatchableLength:
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Examples
Expand All @@ -172,7 +167,7 @@ def norm(self) -> BatchableLength:
>>> import coordinax as cx
>>> q = cx.PolarVector(r=Quantity(3, "kpc"), phi=Quantity(90, "deg"))
>>> q.norm()
Quantity['length'](Array(3., dtype=float32), unit='kpc')
Distance(Array(3., dtype=float32), unit='kpc')
"""
return self.r
Expand All @@ -185,12 +180,12 @@ def norm(self) -> BatchableLength:
class CartesianDifferential2D(Abstract2DVectorDifferential):
"""Cartesian differential representation."""

d_x: BatchableSpeed = eqx.field(
d_x: ct.BatchableSpeed = eqx.field(
converter=partial(Quantity["speed"].constructor, dtype=float)
)
r"""X coordinate differential :math:`\dot{x} \in (-\infty,+\infty)`."""

d_y: BatchableSpeed = eqx.field(
d_y: ct.BatchableSpeed = eqx.field(
converter=partial(Quantity["speed"].constructor, dtype=float)
)
r"""Y coordinate differential :math:`\dot{y} \in (-\infty,+\infty)`."""
Expand All @@ -201,7 +196,7 @@ def integral_cls(cls) -> type[Cartesian2DVector]:
return Cartesian2DVector

@partial(jax.jit)
def norm(self, _: Abstract2DVector | None = None, /) -> BatchableSpeed:
def norm(self, _: Abstract2DVector | None = None, /) -> ct.BatchableSpeed:
"""Return the norm of the vector.
Examples
Expand All @@ -220,12 +215,12 @@ def norm(self, _: Abstract2DVector | None = None, /) -> BatchableSpeed:
class PolarDifferential(Abstract2DVectorDifferential):
"""Polar differential representation."""

d_r: BatchableSpeed = eqx.field(
d_r: ct.BatchableSpeed = eqx.field(
converter=partial(Quantity["speed"].constructor, dtype=float)
)
r"""Radial speed :math:`dr/dt \in [-\infty,+\infty]`."""

d_phi: BatchableAngularSpeed = eqx.field(
d_phi: ct.BatchableAngularSpeed = eqx.field(
converter=partial(Quantity["angular speed"].constructor, dtype=float)
)
r"""Polar angular speed :math:`d\phi/dt \in [-\infty,+\infty]`."""
Expand Down
Loading

0 comments on commit 0165f9a

Please sign in to comment.