Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dep quaxed #57

Merged
merged 4 commits into from
Mar 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
"Typing :: Typed",
]
dependencies = [
"array_api_jax_compat >= 0.1",
"quaxed >= 0.2",
"astropy",
"equinox",
"jax",
Expand Down Expand Up @@ -134,7 +134,7 @@
ignore_missing_imports = true
module = [
"array_api.*",
"array_api_jax_compat.*",
"quaxed.*",
"astropy.*",
"equinox.*",
"hypothesis.*",
Expand Down Expand Up @@ -188,7 +188,7 @@

[tool.ruff.lint.isort]
combine-as-imports = true
known-first-party = ["array_api_jax_compat", "jax_quantity"]
known-first-party = ["quaxed", "jax_quantity"]
known-local-folder = ["coordinax"]


Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from jaxtyping import ArrayLike
from plum import dispatch

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from ._utils import classproperty, dataclass_items, dataclass_values, full_shaped
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import equinox as eqx

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from coordinax._typing import BatchableAngle, BatchableLength
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d1/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import equinox as eqx
import jax

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract1DVector, Abstract1DVectorDifferential
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jaxtyping import Shaped
from plum import conversion_method

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract1DVector
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d2/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import equinox as eqx
import jax

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract2DVector, Abstract2DVectorDifferential
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d2/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jaxtyping import Shaped
from plum import conversion_method

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract2DVector
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d2/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from plum import dispatch

import array_api_jax_compat as xp
import quaxed.array_api as xp

from .base import Abstract2DVector, Abstract2DVectorDifferential
from .builtin import (
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# pylint: disable=duplicate-code
"""3-dimensional."""
"""3-dimensional representations."""

from . import base, builtin, compat, operate, transform
from .base import *
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d3/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import equinox as eqx
import jax

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract3DVector, Abstract3DVectorDifferential
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d3/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from jaxtyping import Shaped
from plum import conversion_method, convert

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract3DVector
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d3/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from plum import dispatch

import array_api_jax_compat as xp
import quaxed.array_api as xp

from .base import Abstract3DVector, Abstract3DVectorDifferential
from .builtin import (
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d4/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from jaxtyping import Shaped
from plum import conversion_method, convert

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .spacetime import FourVector
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_d4/spacetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import jax.numpy as jnp
from jaxtyping import Shaped

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import Abstract4DVector
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/_transform.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Representation of coordinates in different systems."""
"""Transformations between representations."""

__all__ = ["represent_as"]

Expand All @@ -10,7 +10,7 @@
import jax
from plum import dispatch

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from ._base import AbstractVector, AbstractVectorDifferential
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
runtime_checkable,
)

import array_api_jax_compat as xp
import quaxed.array_api as xp

if TYPE_CHECKING:
from coordinax._base import AbstractVectorBase
Expand Down
4 changes: 2 additions & 2 deletions src/coordinax/operators/_galilean/boost.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jax.numpy as jnp
from plum import convert

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import AbstractGalileanOperator
Expand Down Expand Up @@ -47,7 +47,7 @@ class GalileanBoostOperator(AbstractGalileanOperator):
--------
We start with the required imports:

>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> from coordinax import CartesianDifferential3D, Cartesian3DVector
>>> import coordinax.operators as co
Expand Down
2 changes: 1 addition & 1 deletion src/coordinax/operators/_galilean/composite.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import equinox as eqx

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import AbstractGalileanOperator
Expand Down
10 changes: 5 additions & 5 deletions src/coordinax/operators/_galilean/rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from plum import convert
from quax import quaxify

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import AbstractGalileanOperator
Expand Down Expand Up @@ -145,7 +145,7 @@ def is_inertial(self) -> Literal[True]:

Examples
--------
>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> from coordinax.operators import GalileanRotationOperator

Expand All @@ -166,7 +166,7 @@ def inverse(self) -> "GalileanRotationOperator":

Examples
--------
>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> from coordinax.operators import GalileanRotationOperator

Expand Down Expand Up @@ -194,7 +194,7 @@ def __call__(

Examples
--------
>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> from coordinax import Cartesian3DVector, CartesianDifferential3D
>>> from coordinax.operators import GalileanRotationOperator
Expand Down Expand Up @@ -226,7 +226,7 @@ def __call__(

Examples
--------
>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> from coordinax import Cartesian3DVector, CartesianDifferential3D
>>> from coordinax.operators import GalileanRotationOperator
Expand Down
6 changes: 3 additions & 3 deletions src/coordinax/operators/_galilean/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import jax.numpy as jnp
from plum import convert

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .base import AbstractGalileanOperator
Expand Down Expand Up @@ -79,7 +79,7 @@ class GalileanSpatialTranslationOperator(AbstractGalileanOperator):
--------
We start with the required imports:

>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> import coordinax as cx

Expand Down Expand Up @@ -324,7 +324,7 @@ class GalileanTranslationOperator(AbstractGalileanOperator):
--------
We start with the required imports:

>>> import array_api_jax_compat as xp
>>> import quaxed.array_api as xp
>>> from jax_quantity import Quantity
>>> import coordinax.operators as co

Expand Down
2 changes: 1 addition & 1 deletion tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import pytest
from quax import quaxify

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from coordinax import (
Expand Down
2 changes: 1 addition & 1 deletion tests/test_d2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest
from quax import quaxify

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .test_base import AbstractVectorDifferentialTest, AbstractVectorTest, array_equal
Expand Down
2 changes: 1 addition & 1 deletion tests/test_d3.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from astropy.coordinates.tests.test_representation import representation_equal
from plum import convert

import array_api_jax_compat as xp
import quaxed.array_api as xp
from jax_quantity import Quantity

from .test_base import AbstractVectorDifferentialTest, AbstractVectorTest, array_equal
Expand Down