Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

build: unxt #60

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"jax",
"jaxlib",
"jaxtyping",
"jax_quantity @ git+https://github.com/GalacticDynamics/jax-quantity.git",
"unxt @ git+https://github.com/GalacticDynamics/unxt.git",
"quax>=0.0.3",
]
description = "Vectors in JAX"
Expand Down Expand Up @@ -188,7 +188,7 @@

[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = ["quaxed", "jax_quantity"]
known-first-party = ["quaxed", "unxt"]
known-local-folder = ["coordinax"]


Expand Down
40 changes: 20 additions & 20 deletions src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from plum import dispatch

import quaxed.array_api as xp
from jax_quantity import Quantity
from unxt import Quantity

from ._utils import classproperty, dataclass_items, dataclass_values, full_shaped
from coordinax._typing import Unit
Expand Down Expand Up @@ -83,7 +83,7 @@ def constructor(
Examples
--------
>>> import jax.numpy as jnp
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector

>>> xs = {"x": Quantity(1, "m"), "y": Quantity(2, "m"), "z": Quantity(3, "m")}
Expand Down Expand Up @@ -127,7 +127,7 @@ def constructor(
Examples
--------
>>> import jax.numpy as jnp
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector

>>> xs = Quantity([1, 2, 3], "meter")
Expand Down Expand Up @@ -183,7 +183,7 @@ def __getitem__(self, index: Any) -> "Self":
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector

We can slice a vector:
Expand All @@ -205,7 +205,7 @@ def mT(self) -> "Self": # noqa: N802
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector

We can transpose a vector:
Expand All @@ -228,7 +228,7 @@ def ndim(self) -> int:
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector

We can get the number of dimensions of a vector:
Expand Down Expand Up @@ -265,7 +265,7 @@ def shape(self) -> Any:
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector

We can get the shape of a vector:
Expand Down Expand Up @@ -299,7 +299,7 @@ def size(self) -> int:
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector

We can get the size of a vector:
Expand Down Expand Up @@ -333,7 +333,7 @@ def T(self) -> "Self": # noqa: N802
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector

We can transpose a vector:
Expand Down Expand Up @@ -366,7 +366,7 @@ def to_device(self, device: None | Device = None) -> "Self":
We assume the following imports:

>>> from jax import devices
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector

We can move a vector to a new device:
Expand All @@ -392,7 +392,7 @@ def flatten(self) -> "Self":
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector

We can flatten a vector:
Expand Down Expand Up @@ -427,7 +427,7 @@ def reshape(self, *args: Any, order: str = "C") -> "Self":
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector

We can reshape a vector:
Expand Down Expand Up @@ -477,7 +477,7 @@ def asdict(
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector

We can get the vector as a mapping:
Expand Down Expand Up @@ -568,7 +568,7 @@ def to_units(
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector, SphericalVector

We can convert a vector to the given units:
Expand Down Expand Up @@ -642,7 +642,7 @@ def constructor( # noqa: D417
Examples
--------
>>> import jax.numpy as jnp
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector

>>> x, y, z = Quantity(1, "meter"), Quantity(2, "meter"), Quantity(3, "meter")
Expand Down Expand Up @@ -808,7 +808,7 @@ def represent_as(self, target: type[VT], /, *args: Any, **kwargs: Any) -> VT:
--------
We assume the following imports:

>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector, SphericalVector

We can represent a vector as another type:
Expand Down Expand Up @@ -845,7 +845,7 @@ def norm(self) -> Quantity["length"]:
Examples
--------
We assume the following imports:
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian3DVector

We can compute the norm of a vector
Expand Down Expand Up @@ -898,7 +898,7 @@ def __neg__(self) -> "Self":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import RadialDifferential
>>> dr = RadialDifferential(Quantity(1, "m/s"))
>>> -dr
Expand All @@ -922,11 +922,11 @@ def __neg__(self) -> "Self":
def __mul__(
self: "AbstractVectorDifferential", other: Quantity
) -> "AbstractVector":
"""Multiply the vector by a :class:`jax_quantity.Quantity`.
"""Multiply the vector by a :class:`unxt.Quantity`.

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import RadialDifferential

>>> dr = RadialDifferential(Quantity(1, "m/s"))
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import equinox as eqx

import quaxed.array_api as xp
from jax_quantity import Quantity
from unxt import Quantity

from coordinax._typing import BatchableAngle, BatchableLength

Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_d1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from jaxtyping import Shaped

from jax_quantity import Quantity
from unxt import Quantity

from coordinax._base import (
AbstractVector,
Expand Down Expand Up @@ -44,7 +44,7 @@ def constructor(

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector

>>> q = Cartesian1DVector.constructor(Quantity(1, "kpc"))
Expand Down
12 changes: 6 additions & 6 deletions src/coordinax/_d1/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax

import quaxed.array_api as xp
from jax_quantity import Quantity
from unxt import Quantity

from .base import Abstract1DVector, Abstract1DVectorDifferential
from coordinax._base import AbstractVector
Expand Down Expand Up @@ -51,7 +51,7 @@ def __neg__(self) -> "Self":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector
>>> q = Cartesian1DVector.constructor(Quantity([1], "kpc"))
>>> -q
Expand All @@ -70,7 +70,7 @@ def __add__(self, other: Any, /) -> "Cartesian1DVector":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector, RadialVector

>>> q = Cartesian1DVector.constructor(Quantity([1], "kpc"))
Expand All @@ -96,7 +96,7 @@ def __sub__(self, other: Any, /) -> "Cartesian1DVector":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector, RadialVector

>>> q = Cartesian1DVector.constructor(Quantity([1], "kpc"))
Expand All @@ -123,7 +123,7 @@ def norm(self) -> BatchableLength:

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian1DVector, RadialVector

>>> q = Cartesian1DVector.constructor(Quantity([-1], "kpc"))
Expand Down Expand Up @@ -175,7 +175,7 @@ def norm(self, _: Abstract1DVector | None = None, /) -> BatchableSpeed:

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import CartesianDifferential1D
>>> q = CartesianDifferential1D.constructor(Quantity([-1], "km/s"))
>>> q.norm()
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_d1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from plum import conversion_method

import quaxed.array_api as xp
from jax_quantity import Quantity
from unxt import Quantity

from .base import Abstract1DVector
from .builtin import Cartesian1DVector, CartesianDifferential1D
Expand All @@ -19,7 +19,7 @@

@conversion_method(type_from=Abstract1DVector, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: Abstract1DVector, /) -> Shaped[Quantity["length"], "*batch 1"]:
"""`coordinax.Abstract1DVector` -> `jax_quantity.Quantity`."""
"""`coordinax.Abstract1DVector` -> `unxt.Quantity`."""
cart = full_shaped(obj.represent_as(Cartesian1DVector))
return xp.stack(tuple(dataclass_values(cart)), axis=-1)

Expand All @@ -28,5 +28,5 @@ def vec_to_q(obj: Abstract1DVector, /) -> Shaped[Quantity["length"], "*batch 1"]
def vec_diff_to_q(
obj: CartesianDifferential1D, /
) -> Shaped[Quantity["speed"], "*batch 1"]:
"""`coordinax.CartesianDifferential1D` -> `jax_quantity.Quantity`."""
"""`coordinax.CartesianDifferential1D` -> `unxt.Quantity`."""
return xp.stack(tuple(dataclass_values(full_shaped(obj))), axis=-1)
2 changes: 1 addition & 1 deletion src/coordinax/_d1/operate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jaxtyping import Shaped
from plum import convert

from jax_quantity import Quantity
from unxt import Quantity

from .builtin import Cartesian1DVector
from coordinax._typing import TimeBatchOrScalar
Expand Down
14 changes: 7 additions & 7 deletions src/coordinax/_d2/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import jax

import quaxed.array_api as xp
from jax_quantity import Quantity
from unxt import Quantity

from .base import Abstract2DVector, Abstract2DVectorDifferential
from coordinax._base import AbstractVector
Expand Down Expand Up @@ -61,7 +61,7 @@ def __neg__(self) -> "Self":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector

>>> q = Cartesian2DVector.constructor(Quantity([1, 2], "kpc"))
Expand All @@ -79,7 +79,7 @@ def __add__(self, other: Any, /) -> "Cartesian2DVector":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector, PolarVector
>>> cart = Cartesian2DVector.constructor(Quantity([1, 2], "kpc"))
>>> polr = PolarVector(r=Quantity(3, "kpc"), phi=Quantity(90, "deg"))
Expand All @@ -100,7 +100,7 @@ def __sub__(self, other: Any, /) -> "Cartesian2DVector":

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector, PolarVector
>>> cart = Cartesian2DVector.constructor(Quantity([1, 2], "kpc"))
>>> polr = PolarVector(r=Quantity(3, "kpc"), phi=Quantity(90, "deg"))
Expand All @@ -122,7 +122,7 @@ def norm(self) -> BatchableLength:

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import Cartesian2DVector
>>> q = Cartesian2DVector.constructor(Quantity([3, 4], "kpc"))
>>> q.norm()
Expand Down Expand Up @@ -165,7 +165,7 @@ def norm(self) -> BatchableLength:

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import PolarVector
>>> q = PolarVector(r=Quantity(3, "kpc"), phi=Quantity(90, "deg"))
>>> q.norm()
Expand Down Expand Up @@ -203,7 +203,7 @@ def norm(self, _: Abstract2DVector | None = None, /) -> BatchableSpeed:

Examples
--------
>>> from jax_quantity import Quantity
>>> from unxt import Quantity
>>> from coordinax import CartesianDifferential2D
>>> v = CartesianDifferential2D.constructor(Quantity([3, 4], "km/s"))
>>> v.norm()
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/_d2/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from plum import conversion_method

import quaxed.array_api as xp
from jax_quantity import Quantity
from unxt import Quantity

from .base import Abstract2DVector
from .builtin import Cartesian2DVector, CartesianDifferential2D
Expand All @@ -19,7 +19,7 @@

@conversion_method(type_from=Abstract2DVector, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: Abstract2DVector, /) -> Shaped[Quantity["length"], "*batch 2"]:
"""`coordinax.Abstract2DVector` -> `jax_quantity.Quantity`."""
"""`coordinax.Abstract2DVector` -> `unxt.Quantity`."""
cart = full_shaped(obj.represent_as(Cartesian2DVector))
return xp.stack(tuple(dataclass_values(cart)), axis=-1)

Expand All @@ -28,5 +28,5 @@ def vec_to_q(obj: Abstract2DVector, /) -> Shaped[Quantity["length"], "*batch 2"]
def vec_diff_to_q(
obj: CartesianDifferential2D, /
) -> Shaped[Quantity["speed"], "*batch 2"]:
"""`coordinax.CartesianDifferential2D` -> `jax_quantity.Quantity`."""
"""`coordinax.CartesianDifferential2D` -> `unxt.Quantity`."""
return xp.stack(tuple(dataclass_values(full_shaped(obj))), axis=-1)
Loading
Loading