Skip to content

Commit

Permalink
feat: additional spherical representations (#80)
Browse files Browse the repository at this point in the history
* feat: mathspherical
* fix: glob collect doctests
* docs: tests
* feat: lonlat
* feat: loncoslat
* refactor: transforms as module

Signed-off-by: nstarman <[email protected]>
  • Loading branch information
nstarman authored Mar 30, 2024
1 parent 0165f9a commit 140380f
Show file tree
Hide file tree
Showing 19 changed files with 3,002 additions and 1,419 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
"--showlocals",
"--strict-markers",
"--strict-config",
"--doctest-glob='*.rst | *.py'",
]
filterwarnings = [
"error",
Expand Down Expand Up @@ -208,6 +209,7 @@
"missing-function-docstring", # TODO: resolve
"missing-module-docstring",
"no-member", # handled by mypy
"no-value-for-parameter", # pylint doesn't understand multiple dispatch
"not-a-mapping", # pylint doesn't understand dataclass fields
"protected-access", # handled by ruff
"redefined-builtin", # handled by ruff
Expand Down
8 changes: 4 additions & 4 deletions src/coordinax/_base_dif.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ def integral_cls(cls) -> type["AbstractVectorDifferential"]:
--------
>>> from coordinax import RadialDifferential, SphericalDifferential
>>> RadialDifferential.integral_cls
<class 'coordinax._d1.builtin.RadialVector'>
>>> RadialDifferential.integral_cls.__name__
'RadialVector'
>>> SphericalDifferential.integral_cls
<class 'coordinax._d3.builtin.SphericalVector'>
>>> SphericalDifferential.integral_cls.__name__
'SphericalVector'
"""
raise NotImplementedError
Expand Down
8 changes: 4 additions & 4 deletions src/coordinax/_base_vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,11 @@ def differential_cls(cls) -> type["AbstractVectorDifferential"]:
--------
>>> from coordinax import RadialVector, SphericalVector
>>> RadialVector.differential_cls
<class 'coordinax._d1.builtin.RadialDifferential'>
>>> RadialVector.differential_cls.__name__
'RadialDifferential'
>>> SphericalVector.differential_cls
<class 'coordinax._d3.builtin.SphericalDifferential'>
>>> SphericalVector.differential_cls.__name__
'SphericalDifferential'
"""
raise NotImplementedError
Expand Down
78 changes: 0 additions & 78 deletions src/coordinax/_d2/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@ def represent_as(
@dispatch.multi(
(Cartesian2DVector, type[Cartesian2DVector]),
(PolarVector, type[PolarVector]),
# (LnPolarVector, type[LnPolarVector]),
# (Log10PolarVector, type[Log10PolarVector]),
)
def represent_as(
current: Abstract2DVector, target: type[Abstract2DVector], /, **kwargs: Any
Expand All @@ -57,20 +55,6 @@ def represent_as(
return current


# @dispatch.multi(
# (Cartesian2DVector, type[LnPolarVector]),
# (Cartesian2DVector, type[Log10PolarVector]),
# (LnPolarVector, type[Cartesian2DVector]),
# (Log10PolarVector, type[Cartesian2DVector]),
# )
# def represent_as(
# current: Abstract2DVector, target: type[Abstract2DVector], /, **kwargs: Any
# ) -> Abstract2DVector:
# """Abstract2DVector -> PolarVector -> Abstract2DVector."""
# polar = represent_as(current, PolarVector)
# return represent_as(polar, target)


# =============================================================================
# Cartesian2DVector

Expand Down Expand Up @@ -107,65 +91,3 @@ def represent_as(
x = current.r * xp.cos(current.phi)
y = current.r * xp.sin(current.phi)
return target(x=x, y=y)


# @dispatch
# def represent_as(
# current: PolarVector, target: type[LnPolarVector], /, **kwargs: Any
# ) -> LnPolarVector:
# """PolarVector -> LnPolarVector."""
# return target(lnr=xp.log(current.r), phi=current.phi)


# @dispatch
# def represent_as(
# current: PolarVector, target: type[Log10PolarVector], /, **kwargs: Any
# ) -> Log10PolarVector:
# """PolarVector -> Log10PolarVector."""
# return target(log10r=xp.log10(current.r), phi=current.phi)


# # =============================================================================
# # LnPolarVector

# # -----------------------------------------------
# # 2D


# @dispatch
# def represent_as(
# current: LnPolarVector, target: type[PolarVector], /, **kwargs: Any
# ) -> PolarVector:
# """LnPolarVector -> PolarVector."""
# return target(r=xp.exp(current.lnr), phi=current.phi)


# @dispatch
# def represent_as(
# current: LnPolarVector, target: type[Log10PolarVector], /, **kwargs: Any
# ) -> Log10PolarVector:
# """LnPolarVector -> Log10PolarVector."""
# return target(log10r=current.lnr / xp.log(10), phi=current.phi)


# # =============================================================================
# # Log10PolarVector

# # -----------------------------------------------
# # 2D


# @dispatch
# def represent_as(
# current: Log10PolarVector, target: type[PolarVector], /, **kwargs: Any
# ) -> PolarVector:
# """Log10PolarVector -> PolarVector."""
# return target(r=xp.pow(10, current.log10r), phi=current.phi)


# @dispatch
# def represent_as(
# current: Log10PolarVector, target: type[LnPolarVector], /, **kwargs: Any
# ) -> LnPolarVector:
# """Log10PolarVector -> LnPolarVector."""
# return target(lnr=current.log10r * xp.log(10), phi=current.phi)
4 changes: 3 additions & 1 deletion src/coordinax/_d3/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
# pylint: disable=duplicate-code
"""3-dimensional representations."""

from . import base, builtin, compat, operate, transform
from . import base, builtin, compat, operate, sphere, transform
from .base import *
from .builtin import *
from .compat import *
from .operate import *
from .sphere import *
from .transform import *

__all__: list[str] = []
__all__ += base.__all__
__all__ += builtin.__all__
__all__ += sphere.__all__
__all__ += transform.__all__
__all__ += operate.__all__
__all__ += compat.__all__
80 changes: 2 additions & 78 deletions src/coordinax/_d3/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
__all__ = [
# Position
"Cartesian3DVector",
"SphericalVector",
"CylindricalVector",
# Differential
"CartesianDifferential3D",
"SphericalDifferential",
"CylindricalDifferential",
]

Expand All @@ -19,12 +17,12 @@
import jax

import quaxed.array_api as xp
from unxt import Distance, Quantity
from unxt import Quantity

import coordinax._typing as ct
from .base import Abstract3DVector, Abstract3DVectorDifferential
from coordinax._base_vec import AbstractVector
from coordinax._checks import check_phi_range, check_r_non_negative, check_theta_range
from coordinax._checks import check_phi_range, check_r_non_negative
from coordinax._converters import converter_phi_to_range
from coordinax._utils import classproperty

Expand Down Expand Up @@ -134,55 +132,6 @@ def norm(self) -> ct.BatchableLength:
return xp.sqrt(self.x**2 + self.y**2 + self.z**2)


@final
class SphericalVector(Abstract3DVector):
"""Spherical vector representation."""

r: ct.BatchableDistance = eqx.field(
converter=partial(Distance.constructor, dtype=float)
)
r"""Radial distance :math:`r \in [0,+\infty)`."""

phi: ct.BatchableAngle = eqx.field(
converter=lambda x: converter_phi_to_range(
Quantity["angle"].constructor(x, dtype=float) # pylint: disable=E1120
)
)
r"""Azimuthal angle :math:`\phi \in [0,360)`."""

theta: ct.BatchableAngle = eqx.field(
converter=partial(Quantity["angle"].constructor, dtype=float)
)
r"""Inclination angle :math:`\phi \in [0,180]`."""

def __check_init__(self) -> None:
"""Check the validity of the initialisation."""
check_r_non_negative(self.r)
check_theta_range(self.theta)
check_phi_range(self.phi)

@classproperty
@classmethod
def differential_cls(cls) -> type["SphericalDifferential"]:
return SphericalDifferential

@partial(jax.jit)
def norm(self) -> ct.BatchableLength:
"""Return the norm of the vector.
Examples
--------
>>> from unxt import Quantity
>>> from coordinax import SphericalVector
>>> s = SphericalVector(r=Quantity(3, "kpc"), theta=Quantity(90, "deg"),
... phi=Quantity(0, "deg"))
>>> s.norm()
Distance(Array(3., dtype=float32), unit='kpc')
"""
return self.r


@final
class CylindricalVector(Abstract3DVector):
"""Cylindrical vector representation."""
Expand Down Expand Up @@ -277,31 +226,6 @@ def norm(self, _: Abstract3DVector | None = None, /) -> ct.BatchableSpeed:
return xp.sqrt(self.d_x**2 + self.d_y**2 + self.d_z**2)


@final
class SphericalDifferential(Abstract3DVectorDifferential):
"""Spherical differential representation."""

d_r: ct.BatchableSpeed = eqx.field(
converter=partial(Quantity["speed"].constructor, dtype=float)
)
r"""Radial speed :math:`dr/dt \in [-\infty, \infty]."""

d_theta: ct.BatchableAngularSpeed = eqx.field(
converter=partial(Quantity["angular speed"].constructor, dtype=float)
)
r"""Inclination speed :math:`d\theta/dt \in [-\infty, \infty]."""

d_phi: ct.BatchableAngularSpeed = eqx.field(
converter=partial(Quantity["angular speed"].constructor, dtype=float)
)
r"""Azimuthal speed :math:`d\phi/dt \in [-\infty, \infty]."""

@classproperty
@classmethod
def integral_cls(cls) -> type[SphericalVector]:
return SphericalVector


@final
class CylindricalDifferential(Abstract3DVectorDifferential):
"""Cylindrical differential representation."""
Expand Down
Loading

0 comments on commit 140380f

Please sign in to comment.