Skip to content

Commit

Permalink
test: jit rep_as (#51)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Mar 3, 2024
1 parent 2c96275 commit 6c98787
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
"jax",
"jaxlib",
"jaxtyping",
"jax_quantity >= 0.2.1",
"jax_quantity @ git+https://github.com/GalacticDynamics/jax-quantity.git",
"quax>=0.0.3",
]
description = "Vectors in JAX"
Expand Down
26 changes: 21 additions & 5 deletions src/coordinax/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import Callable, Mapping
from dataclasses import fields, replace
from functools import partial
from inspect import isabstract
from types import MappingProxyType
from typing import TYPE_CHECKING, Any, Literal, TypeVar

Expand All @@ -32,11 +33,8 @@
VT = TypeVar("VT", bound="AbstractVector")
DT = TypeVar("DT", bound="AbstractVectorDifferential")


_0m = Quantity(0, "meter")
_0d = Quantity(0, "rad")
_pid = Quantity(xp.pi, "rad")
_2pid = Quantity(2 * xp.pi, "rad")
VECTOR_CLASSES: list[type["AbstractVector"]] = []
DIFFERENTIAL_CLASSES: list[type["AbstractVectorDifferential"]] = []


class AbstractVectorBase(eqx.Module): # type: ignore[misc]
Expand Down Expand Up @@ -678,6 +676,17 @@ def constructor( # noqa: D417
class AbstractVector(AbstractVectorBase): # pylint: disable=abstract-method
"""Abstract representation of coordinates in different systems."""

def __init_subclass__(cls, **kwargs: Any) -> None:
"""Initialize the subclass.
The subclass is registered if it is not an abstract class.
"""
# TODO: a more robust check using equinox.
if isabstract(cls) or cls.__name__.startswith("Abstract"):
return

VECTOR_CLASSES.append(cls)

@classproperty
@classmethod
@abstractmethod
Expand Down Expand Up @@ -855,6 +864,13 @@ def norm(self) -> Quantity["length"]:
class AbstractVectorDifferential(AbstractVectorBase): # pylint: disable=abstract-method
"""Abstract representation of vector differentials in different systems."""

def __init_subclass__(cls, **kwargs: Any) -> None:
"""Initialize the subclass.
The subclass is registered.
"""
DIFFERENTIAL_CLASSES.append(cls)

@classproperty
@classmethod
@abstractmethod
Expand Down
3 changes: 3 additions & 0 deletions src/coordinax/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from coordinax._base import AbstractVectorBase


################################################################################


def dataclass_values(obj: "DataclassInstance") -> Iterator[Any]:
"""Return the values of a dataclass instance."""
yield from (getattr(obj, f.name) for f in fields(obj))
Expand Down
40 changes: 40 additions & 0 deletions tests/test_jax_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
"""Test using Jax operations."""

from functools import partial

import astropy.units as u
import jax
import pytest

from jax_quantity import Quantity

import coordinax as cx
from coordinax._base import VECTOR_CLASSES
from coordinax._utils import dataclass_items

VECTOR_CLASSES_3D = [c for c in VECTOR_CLASSES if issubclass(c, cx.Abstract3DVector)]


# TODO: cycle through all representations
@pytest.fixture(params=VECTOR_CLASSES_3D)
def q(request) -> cx.AbstractVector:
"""Fixture for 3D Vectors."""
q = cx.Cartesian3DVector.constructor(Quantity([1, 2, 3], unit=u.kpc))
return q.represent_as(request.param)


@partial(jax.jit, static_argnums=(1,))
def func(q: cx.AbstractVector, target: type[cx.AbstractVector]) -> cx.AbstractVector:
return q.represent_as(target)


@pytest.mark.parametrize("target", VECTOR_CLASSES_3D)
def test_jax_through_representation(
q: cx.AbstractVector, target: type[cx.AbstractVector]
) -> None:
"""Test using Jax operations through representation."""
newq = func(q, target)

assert isinstance(newq, cx.AbstractVector)
for k, f in dataclass_items(newq):
assert isinstance(f, Quantity), k

0 comments on commit 6c98787

Please sign in to comment.