diff --git a/pyproject.toml b/pyproject.toml index cee52c94..929d1086 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ "Typing :: Typed", ] dependencies = [ - "array_api_jax_compat >= 0.1", + "quaxed >= 0.2", "astropy", "equinox", "jax", @@ -134,7 +134,7 @@ ignore_missing_imports = true module = [ "array_api.*", - "array_api_jax_compat.*", + "quaxed.*", "astropy.*", "equinox.*", "hypothesis.*", @@ -188,7 +188,7 @@ [tool.ruff.lint.isort] combine-as-imports = true - known-first-party = ["array_api_jax_compat", "jax_quantity"] + known-first-party = ["quaxed", "jax_quantity"] known-local-folder = ["coordinax"] diff --git a/src/coordinax/_base.py b/src/coordinax/_base.py index 27ee4692..210fabfc 100644 --- a/src/coordinax/_base.py +++ b/src/coordinax/_base.py @@ -20,7 +20,7 @@ from jaxtyping import ArrayLike from plum import dispatch -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from ._utils import classproperty, dataclass_items, dataclass_values, full_shaped diff --git a/src/coordinax/_checks.py b/src/coordinax/_checks.py index af65c7e5..9cd313aa 100644 --- a/src/coordinax/_checks.py +++ b/src/coordinax/_checks.py @@ -5,7 +5,7 @@ import equinox as eqx -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from coordinax._typing import BatchableAngle, BatchableLength diff --git a/src/coordinax/_d1/builtin.py b/src/coordinax/_d1/builtin.py index d0d79156..cb07672c 100644 --- a/src/coordinax/_d1/builtin.py +++ b/src/coordinax/_d1/builtin.py @@ -16,7 +16,7 @@ import equinox as eqx import jax -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract1DVector, Abstract1DVectorDifferential diff --git a/src/coordinax/_d1/compat.py b/src/coordinax/_d1/compat.py index fe1a79c2..0649c8b9 100644 --- a/src/coordinax/_d1/compat.py +++ b/src/coordinax/_d1/compat.py @@ -6,7 +6,7 @@ from jaxtyping import Shaped from plum import conversion_method -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract1DVector diff --git a/src/coordinax/_d2/builtin.py b/src/coordinax/_d2/builtin.py index 666010d7..7d000de4 100644 --- a/src/coordinax/_d2/builtin.py +++ b/src/coordinax/_d2/builtin.py @@ -16,7 +16,7 @@ import equinox as eqx import jax -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract2DVector, Abstract2DVectorDifferential diff --git a/src/coordinax/_d2/compat.py b/src/coordinax/_d2/compat.py index 5118e023..1e279a7d 100644 --- a/src/coordinax/_d2/compat.py +++ b/src/coordinax/_d2/compat.py @@ -6,7 +6,7 @@ from jaxtyping import Shaped from plum import conversion_method -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract2DVector diff --git a/src/coordinax/_d2/transform.py b/src/coordinax/_d2/transform.py index 529734d2..85b6e977 100644 --- a/src/coordinax/_d2/transform.py +++ b/src/coordinax/_d2/transform.py @@ -6,7 +6,7 @@ from plum import dispatch -import array_api_jax_compat as xp +import quaxed.array_api as xp from .base import Abstract2DVector, Abstract2DVectorDifferential from .builtin import ( diff --git a/src/coordinax/_d3/__init__.py b/src/coordinax/_d3/__init__.py index 9b5f7dbf..f90673bc 100644 --- a/src/coordinax/_d3/__init__.py +++ b/src/coordinax/_d3/__init__.py @@ -1,5 +1,5 @@ # pylint: disable=duplicate-code -"""3-dimensional.""" +"""3-dimensional representations.""" from . import base, builtin, compat, operate, transform from .base import * diff --git a/src/coordinax/_d3/builtin.py b/src/coordinax/_d3/builtin.py index 0c0d5a56..c81e7db7 100644 --- a/src/coordinax/_d3/builtin.py +++ b/src/coordinax/_d3/builtin.py @@ -18,7 +18,7 @@ import equinox as eqx import jax -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract3DVector, Abstract3DVectorDifferential diff --git a/src/coordinax/_d3/compat.py b/src/coordinax/_d3/compat.py index 4c825681..7b7b21bb 100644 --- a/src/coordinax/_d3/compat.py +++ b/src/coordinax/_d3/compat.py @@ -8,7 +8,7 @@ from jaxtyping import Shaped from plum import conversion_method, convert -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract3DVector diff --git a/src/coordinax/_d3/transform.py b/src/coordinax/_d3/transform.py index a5aff246..f70d73ea 100644 --- a/src/coordinax/_d3/transform.py +++ b/src/coordinax/_d3/transform.py @@ -6,7 +6,7 @@ from plum import dispatch -import array_api_jax_compat as xp +import quaxed.array_api as xp from .base import Abstract3DVector, Abstract3DVectorDifferential from .builtin import ( diff --git a/src/coordinax/_d4/compat.py b/src/coordinax/_d4/compat.py index f35d552e..8cdddb5e 100644 --- a/src/coordinax/_d4/compat.py +++ b/src/coordinax/_d4/compat.py @@ -6,7 +6,7 @@ from jaxtyping import Shaped from plum import conversion_method, convert -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .spacetime import FourVector diff --git a/src/coordinax/_d4/spacetime.py b/src/coordinax/_d4/spacetime.py index 3a16a01c..00c45dd8 100644 --- a/src/coordinax/_d4/spacetime.py +++ b/src/coordinax/_d4/spacetime.py @@ -12,7 +12,7 @@ import jax.numpy as jnp from jaxtyping import Shaped -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import Abstract4DVector diff --git a/src/coordinax/_transform.py b/src/coordinax/_transform.py index 5c6acb94..262ba9d5 100644 --- a/src/coordinax/_transform.py +++ b/src/coordinax/_transform.py @@ -1,4 +1,4 @@ -"""Representation of coordinates in different systems.""" +"""Transformations between representations.""" __all__ = ["represent_as"] @@ -10,7 +10,7 @@ import jax from plum import dispatch -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from ._base import AbstractVector, AbstractVectorDifferential diff --git a/src/coordinax/_utils.py b/src/coordinax/_utils.py index f8bdf147..90c9a3ce 100644 --- a/src/coordinax/_utils.py +++ b/src/coordinax/_utils.py @@ -14,7 +14,7 @@ runtime_checkable, ) -import array_api_jax_compat as xp +import quaxed.array_api as xp if TYPE_CHECKING: from coordinax._base import AbstractVectorBase diff --git a/src/coordinax/operators/_galilean/boost.py b/src/coordinax/operators/_galilean/boost.py index 2760ac1a..53af73c0 100644 --- a/src/coordinax/operators/_galilean/boost.py +++ b/src/coordinax/operators/_galilean/boost.py @@ -11,7 +11,7 @@ import jax.numpy as jnp from plum import convert -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import AbstractGalileanOperator @@ -47,7 +47,7 @@ class GalileanBoostOperator(AbstractGalileanOperator): -------- We start with the required imports: - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> from coordinax import CartesianDifferential3D, Cartesian3DVector >>> import coordinax.operators as co diff --git a/src/coordinax/operators/_galilean/composite.py b/src/coordinax/operators/_galilean/composite.py index 575d0e48..2bd58e48 100644 --- a/src/coordinax/operators/_galilean/composite.py +++ b/src/coordinax/operators/_galilean/composite.py @@ -8,7 +8,7 @@ import equinox as eqx -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import AbstractGalileanOperator diff --git a/src/coordinax/operators/_galilean/rotation.py b/src/coordinax/operators/_galilean/rotation.py index bded37e1..269c5a2f 100644 --- a/src/coordinax/operators/_galilean/rotation.py +++ b/src/coordinax/operators/_galilean/rotation.py @@ -14,7 +14,7 @@ from plum import convert from quax import quaxify -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import AbstractGalileanOperator @@ -145,7 +145,7 @@ def is_inertial(self) -> Literal[True]: Examples -------- - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> from coordinax.operators import GalileanRotationOperator @@ -166,7 +166,7 @@ def inverse(self) -> "GalileanRotationOperator": Examples -------- - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> from coordinax.operators import GalileanRotationOperator @@ -194,7 +194,7 @@ def __call__( Examples -------- - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> from coordinax import Cartesian3DVector, CartesianDifferential3D >>> from coordinax.operators import GalileanRotationOperator @@ -226,7 +226,7 @@ def __call__( Examples -------- - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> from coordinax import Cartesian3DVector, CartesianDifferential3D >>> from coordinax.operators import GalileanRotationOperator diff --git a/src/coordinax/operators/_galilean/translation.py b/src/coordinax/operators/_galilean/translation.py index 33ae5c56..9f15c72c 100644 --- a/src/coordinax/operators/_galilean/translation.py +++ b/src/coordinax/operators/_galilean/translation.py @@ -11,7 +11,7 @@ import jax.numpy as jnp from plum import convert -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .base import AbstractGalileanOperator @@ -79,7 +79,7 @@ class GalileanSpatialTranslationOperator(AbstractGalileanOperator): -------- We start with the required imports: - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> import coordinax as cx @@ -324,7 +324,7 @@ class GalileanTranslationOperator(AbstractGalileanOperator): -------- We start with the required imports: - >>> import array_api_jax_compat as xp + >>> import quaxed.array_api as xp >>> from jax_quantity import Quantity >>> import coordinax.operators as co diff --git a/tests/test_base.py b/tests/test_base.py index 2f319583..279632ec 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -10,7 +10,7 @@ import pytest from quax import quaxify -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from coordinax import ( diff --git a/tests/test_d2.py b/tests/test_d2.py index 602c4d3f..9aca6fa1 100644 --- a/tests/test_d2.py +++ b/tests/test_d2.py @@ -5,7 +5,7 @@ import pytest from quax import quaxify -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .test_base import AbstractVectorDifferentialTest, AbstractVectorTest, array_equal diff --git a/tests/test_d3.py b/tests/test_d3.py index ac9953a3..d3647928 100644 --- a/tests/test_d3.py +++ b/tests/test_d3.py @@ -8,7 +8,7 @@ from astropy.coordinates.tests.test_representation import representation_equal from plum import convert -import array_api_jax_compat as xp +import quaxed.array_api as xp from jax_quantity import Quantity from .test_base import AbstractVectorDifferentialTest, AbstractVectorTest, array_equal