Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: convert to quantity #26

Merged
merged 5 commits into from
Feb 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
refactor: add some shape information to conversion
Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman committed Feb 24, 2024
commit 84ac4b5f40d808f1655d4cc91f350a075a0e1064
7 changes: 5 additions & 2 deletions src/vector/_d1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import array_api_jax_compat as xp
from jax_quantity import Quantity
from jaxtyping import Shaped
from plum import conversion_method

from vector._utils import dataclass_values, full_shaped
Expand All @@ -17,13 +18,15 @@


@conversion_method(type_from=Abstract1DVector, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: Abstract1DVector, /) -> Quantity["length"]:
def vec_to_q(obj: Abstract1DVector, /) -> Shaped[Quantity["length"], "*batch 1"]:
"""`vector.Abstract1DVector` -> `jax_quantity.Quantity`."""
cart = full_shaped(obj.represent_as(Cartesian1DVector))
return xp.stack(tuple(dataclass_values(cart)), axis=-1)

Check warning on line 24 in src/vector/_d1/compat.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d1/compat.py#L23-L24

Added lines #L23 - L24 were not covered by tests


@conversion_method(type_from=CartesianDifferential1D, type_to=Quantity) # type: ignore[misc]
def vec_diff_to_q(obj: CartesianDifferential1D, /) -> Quantity["speed"]:
def vec_diff_to_q(
obj: CartesianDifferential1D, /
) -> Shaped[Quantity["speed"], "*batch 1"]:
"""`vector.CartesianDifferential1D` -> `jax_quantity.Quantity`."""
return xp.stack(tuple(dataclass_values(full_shaped(obj))), axis=-1)

Check warning on line 32 in src/vector/_d1/compat.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d1/compat.py#L32

Added line #L32 was not covered by tests
7 changes: 5 additions & 2 deletions src/vector/_d2/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import array_api_jax_compat as xp
from jax_quantity import Quantity
from jaxtyping import Shaped
from plum import conversion_method

from vector._utils import dataclass_values, full_shaped
Expand All @@ -17,13 +18,15 @@


@conversion_method(type_from=Abstract2DVector, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: Abstract2DVector, /) -> Quantity["length"]:
def vec_to_q(obj: Abstract2DVector, /) -> Shaped[Quantity["length"], "*batch 2"]:
"""`vector.Abstract2DVector` -> `jax_quantity.Quantity`."""
cart = full_shaped(obj.represent_as(Cartesian2DVector))
return xp.stack(tuple(dataclass_values(cart)), axis=-1)

Check warning on line 24 in src/vector/_d2/compat.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d2/compat.py#L23-L24

Added lines #L23 - L24 were not covered by tests


@conversion_method(type_from=CartesianDifferential2D, type_to=Quantity) # type: ignore[misc]
def vec_diff_to_q(obj: CartesianDifferential2D, /) -> Quantity["speed"]:
def vec_diff_to_q(
obj: CartesianDifferential2D, /
) -> Shaped[Quantity["speed"], "*batch 2"]:
"""`vector.CartesianDifferential2D` -> `jax_quantity.Quantity`."""
return xp.stack(tuple(dataclass_values(full_shaped(obj))), axis=-1)

Check warning on line 32 in src/vector/_d2/compat.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d2/compat.py#L32

Added line #L32 was not covered by tests
7 changes: 5 additions & 2 deletions src/vector/_d3/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import astropy.coordinates as apyc
import astropy.units as apyu
from jax_quantity import Quantity
from jaxtyping import Shaped
from plum import conversion_method

from vector._utils import dataclass_values, full_shaped
Expand All @@ -26,16 +27,18 @@


@conversion_method(type_from=Abstract3DVector, type_to=Quantity) # type: ignore[misc]
def vec_to_q(obj: Abstract3DVector, /) -> Quantity["length"]:
def vec_to_q(obj: Abstract3DVector, /) -> Shaped[Quantity["length"], "*batch 3"]:
"""`vector.Abstract3DVector` -> `jax_quantity.Quantity`."""
cart = full_shaped(obj.represent_as(Cartesian3DVector))
return xp.stack(tuple(dataclass_values(cart)), axis=-1)

Check warning on line 33 in src/vector/_d3/compat.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/compat.py#L32-L33

Added lines #L32 - L33 were not covered by tests


@conversion_method(type_from=CartesianDifferential3D, type_to=Quantity) # type: ignore[misc]
def vec_diff_to_q(obj: CartesianDifferential3D, /) -> Quantity["speed"]:
def vec_diff_to_q(
obj: CartesianDifferential3D, /
) -> Shaped[Quantity["speed"], "*batch 3"]:
"""`vector.CartesianDifferential3D` -> `jax_quantity.Quantity`."""
return xp.stack(tuple(dataclass_values(full_shaped(obj))), axis=-1)

Check warning on line 41 in src/vector/_d3/compat.py

View check run for this annotation

Codecov / codecov/patch

src/vector/_d3/compat.py#L41

Added line #L41 was not covered by tests


#####################################################################
Expand Down
Loading