Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: additional spherical representations #80

Merged
merged 8 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"--showlocals",
"--strict-markers",
"--strict-config",
"--doctest-glob='*.rst | *.py'",
]
filterwarnings = [
"error",
Expand Down Expand Up @@ -208,6 +209,7 @@
"missing-function-docstring", # TODO: resolve
"missing-module-docstring",
"no-member", # handled by mypy
"no-value-for-parameter", # pylint doesn't understand multiple dispatch
"not-a-mapping", # pylint doesn't understand dataclass fields
"protected-access", # handled by ruff
"redefined-builtin", # handled by ruff
Expand Down
8 changes: 4 additions & 4 deletions src/coordinax/_base_dif.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def integral_cls(cls) -> type["AbstractVectorDifferential"]:
--------
>>> from coordinax import RadialDifferential, SphericalDifferential

>>> RadialDifferential.integral_cls
<class 'coordinax._d1.builtin.RadialVector'>
>>> RadialDifferential.integral_cls.__name__
'RadialVector'

>>> SphericalDifferential.integral_cls
<class 'coordinax._d3.builtin.SphericalVector'>
>>> SphericalDifferential.integral_cls.__name__
'SphericalVector'

"""
raise NotImplementedError
Expand Down
8 changes: 4 additions & 4 deletions src/coordinax/_base_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def differential_cls(cls) -> type["AbstractVectorDifferential"]:
--------
>>> from coordinax import RadialVector, SphericalVector

>>> RadialVector.differential_cls
<class 'coordinax._d1.builtin.RadialDifferential'>
>>> RadialVector.differential_cls.__name__
'RadialDifferential'

>>> SphericalVector.differential_cls
<class 'coordinax._d3.builtin.SphericalDifferential'>
>>> SphericalVector.differential_cls.__name__
'SphericalDifferential'

"""
raise NotImplementedError
Expand Down
78 changes: 0 additions & 78 deletions src/coordinax/_d2/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def represent_as(
@dispatch.multi(
(Cartesian2DVector, type[Cartesian2DVector]),
(PolarVector, type[PolarVector]),
# (LnPolarVector, type[LnPolarVector]),
# (Log10PolarVector, type[Log10PolarVector]),
)
def represent_as(
current: Abstract2DVector, target: type[Abstract2DVector], /, **kwargs: Any
Expand All @@ -57,20 +55,6 @@ def represent_as(
return current


# @dispatch.multi(
# (Cartesian2DVector, type[LnPolarVector]),
# (Cartesian2DVector, type[Log10PolarVector]),
# (LnPolarVector, type[Cartesian2DVector]),
# (Log10PolarVector, type[Cartesian2DVector]),
# )
# def represent_as(
# current: Abstract2DVector, target: type[Abstract2DVector], /, **kwargs: Any
# ) -> Abstract2DVector:
# """Abstract2DVector -> PolarVector -> Abstract2DVector."""
# polar = represent_as(current, PolarVector)
# return represent_as(polar, target)


# =============================================================================
# Cartesian2DVector

Expand Down Expand Up @@ -107,65 +91,3 @@ def represent_as(
x = current.r * xp.cos(current.phi)
y = current.r * xp.sin(current.phi)
return target(x=x, y=y)


# @dispatch
# def represent_as(
# current: PolarVector, target: type[LnPolarVector], /, **kwargs: Any
# ) -> LnPolarVector:
# """PolarVector -> LnPolarVector."""
# return target(lnr=xp.log(current.r), phi=current.phi)


# @dispatch
# def represent_as(
# current: PolarVector, target: type[Log10PolarVector], /, **kwargs: Any
# ) -> Log10PolarVector:
# """PolarVector -> Log10PolarVector."""
# return target(log10r=xp.log10(current.r), phi=current.phi)


# # =============================================================================
# # LnPolarVector

# # -----------------------------------------------
# # 2D


# @dispatch
# def represent_as(
# current: LnPolarVector, target: type[PolarVector], /, **kwargs: Any
# ) -> PolarVector:
# """LnPolarVector -> PolarVector."""
# return target(r=xp.exp(current.lnr), phi=current.phi)


# @dispatch
# def represent_as(
# current: LnPolarVector, target: type[Log10PolarVector], /, **kwargs: Any
# ) -> Log10PolarVector:
# """LnPolarVector -> Log10PolarVector."""
# return target(log10r=current.lnr / xp.log(10), phi=current.phi)


# # =============================================================================
# # Log10PolarVector

# # -----------------------------------------------
# # 2D


# @dispatch
# def represent_as(
# current: Log10PolarVector, target: type[PolarVector], /, **kwargs: Any
# ) -> PolarVector:
# """Log10PolarVector -> PolarVector."""
# return target(r=xp.pow(10, current.log10r), phi=current.phi)


# @dispatch
# def represent_as(
# current: Log10PolarVector, target: type[LnPolarVector], /, **kwargs: Any
# ) -> LnPolarVector:
# """Log10PolarVector -> LnPolarVector."""
# return target(lnr=current.log10r * xp.log(10), phi=current.phi)
4 changes: 3 additions & 1 deletion src/coordinax/_d3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# pylint: disable=duplicate-code
"""3-dimensional representations."""

from . import base, builtin, compat, operate, transform
from . import base, builtin, compat, operate, sphere, transform
from .base import *
from .builtin import *
from .compat import *
from .operate import *
from .sphere import *
from .transform import *

__all__: list[str] = []
__all__ += base.__all__
__all__ += builtin.__all__
__all__ += sphere.__all__
__all__ += transform.__all__
__all__ += operate.__all__
__all__ += compat.__all__
80 changes: 2 additions & 78 deletions src/coordinax/_d3/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
__all__ = [
# Position
"Cartesian3DVector",
"SphericalVector",
"CylindricalVector",
# Differential
"CartesianDifferential3D",
"SphericalDifferential",
"CylindricalDifferential",
]

Expand All @@ -19,12 +17,12 @@
import jax

import quaxed.array_api as xp
from unxt import Distance, Quantity
from unxt import Quantity

import coordinax._typing as ct
from .base import Abstract3DVector, Abstract3DVectorDifferential
from coordinax._base_vec import AbstractVector
from coordinax._checks import check_phi_range, check_r_non_negative, check_theta_range
from coordinax._checks import check_phi_range, check_r_non_negative
from coordinax._converters import converter_phi_to_range
from coordinax._utils import classproperty

Expand Down Expand Up @@ -134,55 +132,6 @@ def norm(self) -> ct.BatchableLength:
return xp.sqrt(self.x**2 + self.y**2 + self.z**2)


@final
class SphericalVector(Abstract3DVector):
"""Spherical vector representation."""

r: ct.BatchableDistance = eqx.field(
converter=partial(Distance.constructor, dtype=float)
)
r"""Radial distance :math:`r \in [0,+\infty)`."""

phi: ct.BatchableAngle = eqx.field(
converter=lambda x: converter_phi_to_range(
Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120
)
)
r"""Azimuthal angle :math:`\phi \in [0,360)`."""

theta: ct.BatchableAngle = eqx.field(
converter=partial(Quantity["angle"].constructor, dtype=float)
)
r"""Inclination angle :math:`\phi \in [0,180]`."""

def __check_init__(self) -> None:
"""Check the validity of the initialisation."""
check_r_non_negative(self.r)
check_theta_range(self.theta)
check_phi_range(self.phi)

@classproperty
@classmethod
def differential_cls(cls) -> type["SphericalDifferential"]:
return SphericalDifferential

@partial(jax.jit)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.

Examples
--------
>>> from unxt import Quantity
>>> from coordinax import SphericalVector
>>> s = SphericalVector(r=Quantity(3, "kpc"), theta=Quantity(90, "deg"),
... phi=Quantity(0, "deg"))
>>> s.norm()
Distance(Array(3., dtype=float32), unit='kpc')

"""
return self.r


@final
class CylindricalVector(Abstract3DVector):
"""Cylindrical vector representation."""
Expand Down Expand Up @@ -277,31 +226,6 @@ def norm(self, _: Abstract3DVector | None = None, /) -> ct.BatchableSpeed:
return xp.sqrt(self.d_x**2 + self.d_y**2 + self.d_z**2)


@final
class SphericalDifferential(Abstract3DVectorDifferential):
"""Spherical differential representation."""

d_r: ct.BatchableSpeed = eqx.field(
converter=partial(Quantity["speed"].constructor, dtype=float)
)
r"""Radial speed :math:`dr/dt \in [-\infty, \infty]."""

d_theta: ct.BatchableAngularSpeed = eqx.field(
converter=partial(Quantity["angular speed"].constructor, dtype=float)
)
r"""Inclination speed :math:`d\theta/dt \in [-\infty, \infty]."""

d_phi: ct.BatchableAngularSpeed = eqx.field(
converter=partial(Quantity["angular speed"].constructor, dtype=float)
)
r"""Azimuthal speed :math:`d\phi/dt \in [-\infty, \infty]."""

@classproperty
@classmethod
def integral_cls(cls) -> type[SphericalVector]:
return SphericalVector


@final
class CylindricalDifferential(Abstract3DVectorDifferential):
"""Cylindrical differential representation."""
Expand Down
Loading
Loading